1 //===------------ BPFCheckAndAdjustIR.cpp - Check and Adjust IR -----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // Check IR and adjust IR for verifier friendly codes. 10 // The following are done for IR checking: 11 // - no relocation globals in PHI node. 12 // The following are done for IR adjustment: 13 // - remove __builtin_bpf_passthrough builtins. Target independent IR 14 // optimizations are done and those builtins can be removed. 15 // - remove llvm.bpf.getelementptr.and.load builtins. 16 // - remove llvm.bpf.getelementptr.and.store builtins. 17 // - for loads and stores with base addresses from non-zero address space 18 // cast base address to zero address space (support for BPF address spaces). 19 // 20 //===----------------------------------------------------------------------===// 21 22 #include "BPF.h" 23 #include "BPFCORE.h" 24 #include "llvm/Analysis/LoopInfo.h" 25 #include "llvm/IR/GlobalVariable.h" 26 #include "llvm/IR/IRBuilder.h" 27 #include "llvm/IR/Instruction.h" 28 #include "llvm/IR/Instructions.h" 29 #include "llvm/IR/IntrinsicsBPF.h" 30 #include "llvm/IR/Module.h" 31 #include "llvm/IR/Type.h" 32 #include "llvm/IR/Value.h" 33 #include "llvm/Pass.h" 34 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 35 36 #define DEBUG_TYPE "bpf-check-and-opt-ir" 37 38 using namespace llvm; 39 40 namespace { 41 42 class BPFCheckAndAdjustIR final : public ModulePass { 43 bool runOnModule(Module &F) override; 44 45 public: 46 static char ID; 47 BPFCheckAndAdjustIR() : ModulePass(ID) {} 48 virtual void getAnalysisUsage(AnalysisUsage &AU) const override; 49 50 private: 51 void checkIR(Module &M); 52 bool adjustIR(Module &M); 53 bool removePassThroughBuiltin(Module &M); 54 bool 
removeCompareBuiltin(Module &M); 55 bool sinkMinMax(Module &M); 56 bool removeGEPBuiltins(Module &M); 57 bool insertASpaceCasts(Module &M); 58 }; 59 } // End anonymous namespace 60 61 char BPFCheckAndAdjustIR::ID = 0; 62 INITIALIZE_PASS(BPFCheckAndAdjustIR, DEBUG_TYPE, "BPF Check And Adjust IR", 63 false, false) 64 65 ModulePass *llvm::createBPFCheckAndAdjustIR() { 66 return new BPFCheckAndAdjustIR(); 67 } 68 69 void BPFCheckAndAdjustIR::checkIR(Module &M) { 70 // Ensure relocation global won't appear in PHI node 71 // This may happen if the compiler generated the following code: 72 // B1: 73 // g1 = @llvm.skb_buff:0:1... 74 // ... 75 // goto B_COMMON 76 // B2: 77 // g2 = @llvm.skb_buff:0:2... 78 // ... 79 // goto B_COMMON 80 // B_COMMON: 81 // g = PHI(g1, g2) 82 // x = load g 83 // ... 84 // If anything likes the above "g = PHI(g1, g2)", issue a fatal error. 85 for (Function &F : M) 86 for (auto &BB : F) 87 for (auto &I : BB) { 88 PHINode *PN = dyn_cast<PHINode>(&I); 89 if (!PN || PN->use_empty()) 90 continue; 91 for (int i = 0, e = PN->getNumIncomingValues(); i < e; ++i) { 92 auto *GV = dyn_cast<GlobalVariable>(PN->getIncomingValue(i)); 93 if (!GV) 94 continue; 95 if (GV->hasAttribute(BPFCoreSharedInfo::AmaAttr) || 96 GV->hasAttribute(BPFCoreSharedInfo::TypeIdAttr)) 97 report_fatal_error("relocation global in PHI node"); 98 } 99 } 100 } 101 102 bool BPFCheckAndAdjustIR::removePassThroughBuiltin(Module &M) { 103 // Remove __builtin_bpf_passthrough()'s which are used to prevent 104 // certain IR optimizations. Now major IR optimizations are done, 105 // remove them. 
  bool Changed = false;
  // Deletion is deferred by one iteration: erasing the current
  // instruction while the range-for over BB sits on it would invalidate
  // the iterator. A call is never a block terminator, so the terminator
  // that follows always triggers the pending erase before the loops end.
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.passthrough"))
          continue;
        Changed = true;
        // Operand 1 is the value being passed through; forward it to all
        // users of the call, then schedule the call for deletion.
        Value *Arg = Call->getArgOperand(1);
        Call->replaceAllUsesWith(Arg);
        ToBeDeleted = Call;
      }
  return Changed;
}

bool BPFCheckAndAdjustIR::removeCompareBuiltin(Module &M) {
  // Remove __builtin_bpf_compare()'s which are used to prevent
  // certain IR optimizations. Now major IR optimizations are done,
  // remove them.
  bool Changed = false;
  // Same deferred-deletion pattern as in removePassThroughBuiltin.
  CallInst *ToBeDeleted = nullptr;
  for (Function &F : M)
    for (auto &BB : F)
      for (auto &I : BB) {
        if (ToBeDeleted) {
          ToBeDeleted->eraseFromParent();
          ToBeDeleted = nullptr;
        }

        auto *Call = dyn_cast<CallInst>(&I);
        if (!Call)
          continue;
        auto *GV = dyn_cast<GlobalValue>(Call->getCalledOperand());
        if (!GV)
          continue;
        if (!GV->getName().starts_with("llvm.bpf.compare"))
          continue;

        Changed = true;
        // Operands: (predicate-constant, lhs, rhs). Materialize the plain
        // icmp the builtin was shielding from optimization.
        Value *Arg0 = Call->getArgOperand(0);
        Value *Arg1 = Call->getArgOperand(1);
        Value *Arg2 = Call->getArgOperand(2);

        // Operand 0 encodes the ICmp predicate as an integer constant.
        auto OpVal = cast<ConstantInt>(Arg0)->getValue().getZExtValue();
        CmpInst::Predicate Opcode = (CmpInst::Predicate)OpVal;

        auto *ICmp = new ICmpInst(Opcode, Arg1, Arg2);
        ICmp->insertBefore(Call->getIterator());

        Call->replaceAllUsesWith(ICmp);
        ToBeDeleted = Call;
      }
  return Changed;
}

// Bookkeeping for one `icmp` candidate found by sinkMinMaxInBB: the
// compare, which side of it holds the min/max call, and any zext/sext
// sitting between the call and the compare.
struct MinMaxSinkInfo {
  ICmpInst *ICmp;                // The comparison to rewrite.
  Value *Other;                  // The non-min/max operand of ICmp.
  ICmpInst::Predicate Predicate; // Predicate as if Other were the LHS.
  CallInst *MinMax;              // The min/max intrinsic call (set later).
  ZExtInst *ZExt;                // Optional zext looked through (or null).
  SExtInst *SExt;                // Optional sext looked through (or null).

  MinMaxSinkInfo(ICmpInst *ICmp, Value
                     *Other, ICmpInst::Predicate Predicate)
      : ICmp(ICmp), Other(Other), Predicate(Predicate), MinMax(nullptr),
        ZExt(nullptr), SExt(nullptr) {}
};

// Rewrite relational compares against min/max intrinsic calls in BB as
// two compares joined by logical and/or (see the transformation table in
// the comment below). Filter selects which min/max calls are eligible.
// Returns true if BB was changed.
static bool sinkMinMaxInBB(BasicBlock &BB,
                           const std::function<bool(Instruction *)> &Filter) {
  // Check if V is:
  //   (fn %a %b) or (ext (fn %a %b))
  // Where:
  //   ext := sext | zext
  //   fn  := smin | umin | smax | umax
  auto IsMinMaxCall = [=](Value *V, MinMaxSinkInfo &Info) {
    // Look through a single zext/sext, remembering it so the rewritten
    // operands can be extended the same way later.
    if (auto *ZExt = dyn_cast<ZExtInst>(V)) {
      V = ZExt->getOperand(0);
      Info.ZExt = ZExt;
    } else if (auto *SExt = dyn_cast<SExtInst>(V)) {
      V = SExt->getOperand(0);
      Info.SExt = SExt;
    }

    auto *Call = dyn_cast<CallInst>(V);
    if (!Call)
      return false;

    auto *Called = dyn_cast<Function>(Call->getCalledOperand());
    if (!Called)
      return false;

    switch (Called->getIntrinsicID()) {
    case Intrinsic::smin:
    case Intrinsic::umin:
    case Intrinsic::smax:
    case Intrinsic::umax:
      break;
    default:
      return false;
    }

    if (!Filter(Call))
      return false;

    Info.MinMax = Call;

    return true;
  };

  // Re-apply the zext/sext recorded in Info (if any) to V, unless V
  // already has the target type.
  auto ZeroOrSignExtend = [](IRBuilder<> &Builder, Value *V,
                             MinMaxSinkInfo &Info) {
    if (Info.SExt) {
      if (Info.SExt->getType() == V->getType())
        return V;
      return Builder.CreateSExt(V, Info.SExt->getType());
    }
    if (Info.ZExt) {
      if (Info.ZExt->getType() == V->getType())
        return V;
      return Builder.CreateZExt(V, Info.ZExt->getType());
    }
    return V;
  };

  bool Changed = false;
  SmallVector<MinMaxSinkInfo, 2> SinkList;

  // Check BB for instructions like:
  //   insn := (icmp %a (fn ...)) | (icmp (fn ...) %a)
  //
  // Where:
  //   fn := min | max | (ext (min ...)) | (ext (max ...))
  //
  // Put such instructions to SinkList.
  // (Collect first, rewrite afterwards: rewriting in place would modify
  // BB while it is being iterated.)
  for (Instruction &I : BB) {
    ICmpInst *ICmp = dyn_cast<ICmpInst>(&I);
    if (!ICmp)
      continue;
    if (!ICmp->isRelational())
      continue;
    // Two candidate views of the compare, normalized so that Other is
    // the LHS of Predicate; exactly one side must be a min/max call,
    // otherwise the instruction is skipped (the XOR below).
    MinMaxSinkInfo First(ICmp, ICmp->getOperand(1),
                         ICmpInst::getSwappedPredicate(ICmp->getPredicate()));
    MinMaxSinkInfo Second(ICmp, ICmp->getOperand(0), ICmp->getPredicate());
    bool FirstMinMax = IsMinMaxCall(ICmp->getOperand(0), First);
    bool SecondMinMax = IsMinMaxCall(ICmp->getOperand(1), Second);
    if (!(FirstMinMax ^ SecondMinMax))
      continue;
    SinkList.push_back(FirstMinMax ? First : Second);
  }

  // Iterate SinkList and replace each (icmp ...) with corresponding
  // `x < a && x < b` or similar expression.
  for (auto &Info : SinkList) {
    ICmpInst *ICmp = Info.ICmp;
    CallInst *MinMax = Info.MinMax;
    Intrinsic::ID IID = MinMax->getCalledFunction()->getIntrinsicID();
    ICmpInst::Predicate P = Info.Predicate;
    // Skip signed compares of unsigned min/max calls.
    if (ICmpInst::isSigned(P) && IID != Intrinsic::smin &&
        IID != Intrinsic::smax)
      continue;

    IRBuilder<> Builder(ICmp);
    Value *X = Info.Other;
    // Extend the min/max arguments exactly as the original value was
    // extended before the compare.
    Value *A = ZeroOrSignExtend(Builder, MinMax->getArgOperand(0), Info);
    Value *B = ZeroOrSignExtend(Builder, MinMax->getArgOperand(1), Info);
    bool IsMin = IID == Intrinsic::smin || IID == Intrinsic::umin;
    bool IsMax = IID == Intrinsic::smax || IID == Intrinsic::umax;
    bool IsLess = ICmpInst::isLE(P) || ICmpInst::isLT(P);
    bool IsGreater = ICmpInst::isGE(P) || ICmpInst::isGT(P);
    assert(IsMin ^ IsMax);
    assert(IsLess ^ IsGreater);

    Value *Replacement;
    Value *LHS = Builder.CreateICmp(P, X, A);
    Value *RHS = Builder.CreateICmp(P, X, B);
    if ((IsLess && IsMin) || (IsGreater && IsMax))
      // x < min(a, b) -> x < a && x < b
      // x > max(a, b) -> x > a && x > b
      Replacement = Builder.CreateLogicalAnd(LHS, RHS);
    else
      // x > min(a, b) -> x > a || x > b
      // x < max(a, b) -> x < a || x < b
      Replacement = Builder.CreateLogicalOr(LHS, RHS);

    ICmp->replaceAllUsesWith(Replacement);

    // Erase the now-dead chain users-first (icmp, then ext, then call);
    // each erase may make the next entry dead.
    Instruction *ToRemove[] = {ICmp, Info.ZExt, Info.SExt, MinMax};
    for (Instruction *I : ToRemove)
      if (I && I->use_empty())
        I->eraseFromParent();

    Changed = true;
  }

  return Changed;
}

// Do the following transformation:
//
//   x < min(a, b) -> x < a && x < b
//   x > min(a, b) -> x > a || x > b
//   x < max(a, b) -> x < a || x < b
//   x > max(a, b) -> x > a && x > b
//
// Such patterns are introduced by LICM.cpp:hoistMinMax()
// transformation and might lead to BPF verification failures for
// older kernels.
//
// To minimize "collateral" changes only do it for icmp + min/max
// calls when icmp is inside a loop and min/max is outside of that
// loop.
//
// Verification failure happens when:
// - RHS operand of some `icmp LHS, RHS` is replaced by some RHS1;
// - verifier can recognize RHS as a constant scalar in some context;
// - verifier can't recognize RHS1 as a constant scalar in the same
//   context;
//
// The "constant scalar" is not a compile time constant, but a register
// that holds a scalar value known to verifier at some point in time
// during abstract interpretation.
//
// See also:
//   https://lore.kernel.org/bpf/20230406164505.1046801-1-yhs@fb.com/
bool BPFCheckAndAdjustIR::sinkMinMax(Module &M) {
  bool Changed = false;

  for (Function &F : M) {
    // LoopInfo is only available for function definitions.
    if (F.isDeclaration())
      continue;

    LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>(F).getLoopInfo();
    for (Loop *L : LI)
      for (BasicBlock *BB : L->blocks()) {
        // Filter out instructions coming from the same loop:
        // only min/max calls defined outside BB's loop are sunk.
        Loop *BBLoop = LI.getLoopFor(BB);
        auto OtherLoopFilter = [&](Instruction *I) {
          return LI.getLoopFor(I->getParent()) != BBLoop;
        };
        Changed |= sinkMinMaxInBB(*BB, OtherLoopFilter);
      }
  }

  return Changed;
}

// LoopInfo is required by sinkMinMax above.
void BPFCheckAndAdjustIR::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<LoopInfoWrapperPass>();
}

// Replace one llvm.bpf.getelementptr.and.load call with the GEP + load
// pair reconstructed by BPFPreserveStaticOffsetPass, forwarding the
// loaded value to the call's users.
static void unrollGEPLoad(CallInst *Call) {
  auto [GEP, Load] = BPFPreserveStaticOffsetPass::reconstructLoad(Call);
  GEP->insertBefore(Call->getIterator());
  Load->insertBefore(Call->getIterator());
  Call->replaceAllUsesWith(Load);
  Call->eraseFromParent();
}

// Replace one llvm.bpf.getelementptr.and.store call with the
// reconstructed GEP + store pair (a store has no users to forward).
static void unrollGEPStore(CallInst *Call) {
  auto [GEP, Store] = BPFPreserveStaticOffsetPass::reconstructStore(Call);
  GEP->insertBefore(Call->getIterator());
  Store->insertBefore(Call->getIterator());
  Call->eraseFromParent();
}

// Unroll all getelementptr.and.{load,store} builtins in F; returns true
// if any were found.
static bool removeGEPBuiltinsInFunc(Function &F) {
  // Collect the calls first: unrolling erases instructions, which must
  // not happen while iterating over the blocks.
  SmallVector<CallInst *> GEPLoads;
  SmallVector<CallInst *> GEPStores;
  for (auto &BB : F)
    for (auto &Insn : BB)
      if (auto *Call = dyn_cast<CallInst>(&Insn))
        if (auto *Called = Call->getCalledFunction())
          switch (Called->getIntrinsicID()) {
          case Intrinsic::bpf_getelementptr_and_load:
            GEPLoads.push_back(Call);
            break;
          case Intrinsic::bpf_getelementptr_and_store:
            GEPStores.push_back(Call);
            break;
          }

  if (GEPLoads.empty() && GEPStores.empty())
    return false;

  for_each(GEPLoads, unrollGEPLoad);
  for_each(GEPStores,
           unrollGEPStore);

  return true;
}

// Rewrites the following builtins:
// - llvm.bpf.getelementptr.and.load
// - llvm.bpf.getelementptr.and.store
// As (load (getelementptr ...)) or (store (getelementptr ...)).
bool BPFCheckAndAdjustIR::removeGEPBuiltins(Module &M) {
  bool Changed = false;
  for (auto &F : M)
    Changed = removeGEPBuiltinsInFunc(F) || Changed;
  return Changed;
}

// Wrap ToWrap with cast to address space zero:
// - if ToWrap is a getelementptr,
//   wrap it's base pointer instead and return a copy;
// - if ToWrap is Instruction, insert address space cast
//   immediately after ToWrap;
// - if ToWrap is not an Instruction (function parameter
//   or a global value), insert address space cast at the
//   beginning of the Function F;
// - use Cache to avoid inserting too many casts;
static Value *aspaceWrapValue(DenseMap<Value *, Value *> &Cache, Function *F,
                              Value *ToWrap) {
  // Reuse a previously created wrapper for this value, if any.
  auto It = Cache.find(ToWrap);
  if (It != Cache.end())
    return It->getSecond();

  if (auto *GEP = dyn_cast<GetElementPtrInst>(ToWrap)) {
    // Recursively wrap the GEP's base pointer, then clone the GEP so the
    // copy computes on the wrapped (address-space-zero) base. The clone
    // is placed right after the original and its result type is mutated
    // to an unqualified (addrspace 0) pointer.
    Value *Ptr = GEP->getPointerOperand();
    Value *WrappedPtr = aspaceWrapValue(Cache, F, Ptr);
    auto *GEPTy = cast<PointerType>(GEP->getType());
    auto *NewGEP = GEP->clone();
    NewGEP->insertAfter(GEP->getIterator());
    NewGEP->mutateType(PointerType::getUnqual(GEPTy->getContext()));
    NewGEP->setOperand(GEP->getPointerOperandIndex(), WrappedPtr);
    NewGEP->setName(GEP->getName());
    Cache[ToWrap] = NewGEP;
    return NewGEP;
  }

  // Non-GEP case: emit an addrspacecast right after the defining
  // instruction, or at the function entry for arguments/globals.
  IRBuilder IB(F->getContext());
  if (Instruction *InsnPtr = dyn_cast<Instruction>(ToWrap))
    IB.SetInsertPoint(*InsnPtr->getInsertionPointAfterDef());
  else
    IB.SetInsertPoint(F->getEntryBlock().getFirstInsertionPt());
  auto *ASZeroPtrTy = IB.getPtrTy(0);
  auto *ACast = IB.CreateAddrSpaceCast(ToWrap, ASZeroPtrTy, ToWrap->getName());
  Cache[ToWrap] = ACast;
  return ACast;
}

// Wrap a pointer operand OpNum of instruction I
// with cast to address space zero
static void aspaceWrapOperand(DenseMap<Value *, Value *> &Cache, Instruction *I,
                              unsigned OpNum) {
  Value *OldOp = I->getOperand(OpNum);
  // Pointers already in address space zero need no wrapping.
  if (OldOp->getType()->getPointerAddressSpace() == 0)
    return;

  Value *NewOp = aspaceWrapValue(Cache, I->getFunction(), OldOp);
  I->setOperand(OpNum, NewOp);
  // Check if there are any remaining users of old GEP,
  // delete those w/o users.
  // Walk up the original GEP chain: erasing one dead GEP may make its
  // own base GEP dead in turn; stop at the first non-GEP or still-used
  // value.
  for (;;) {
    auto *OldGEP = dyn_cast<GetElementPtrInst>(OldOp);
    if (!OldGEP)
      break;
    if (!OldGEP->use_empty())
      break;
    OldOp = OldGEP->getPointerOperand();
    OldGEP->eraseFromParent();
  }
}

// Support for BPF address spaces:
// - for each function in the module M, update pointer operand of
//   each memory access instruction (load/store/cmpxchg/atomicrmw)
//   by casting it from non-zero address space to zero address space, e.g:
//
//   (load (ptr addrspace (N) %p) ...)
//     -> (load (addrspacecast ptr addrspace (N) %p to ptr))
//
// - assign section with name .addr_space.N for globals defined in
//   non-zero address space N
bool BPFCheckAndAdjustIR::insertASpaceCasts(Module &M) {
  bool Changed = false;
  for (Function &F : M) {
    // Cache of value -> wrapped-value, scoped per function so casts
    // inserted in one function are never reused in another.
    DenseMap<Value *, Value *> CastsCache;
    for (BasicBlock &BB : F) {
      for (Instruction &I : BB) {
        unsigned PtrOpNum;

        // Only memory-access instructions carry a pointer operand to wrap.
        if (auto *LD = dyn_cast<LoadInst>(&I))
          PtrOpNum = LD->getPointerOperandIndex();
        else if (auto *ST = dyn_cast<StoreInst>(&I))
          PtrOpNum = ST->getPointerOperandIndex();
        else if (auto *CmpXchg = dyn_cast<AtomicCmpXchgInst>(&I))
          PtrOpNum = CmpXchg->getPointerOperandIndex();
        else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
          PtrOpNum = RMW->getPointerOperandIndex();
        else
          continue;

        aspaceWrapOperand(CastsCache, &I, PtrOpNum);
      }
    }
    // A non-empty cache means at least one cast was inserted.
    Changed |= !CastsCache.empty();
  }
  // Merge all globals within same address space into single
  // .addr_space.<addr space no> section
  for (GlobalVariable &G : M.globals()) {
    // Skip addrspace-0 globals and globals with an explicit section.
    if (G.getAddressSpace() == 0 || G.hasSection())
      continue;
    SmallString<16> SecName;
    raw_svector_ostream OS(SecName);
    OS << ".addr_space." << G.getAddressSpace();
    G.setSection(SecName);
    // Prevent having separate section for constants
    G.setConstant(false);
  }
  return Changed;
}

// Run all IR adjustments in a fixed order; returns true if the module
// was modified by any of them.
bool BPFCheckAndAdjustIR::adjustIR(Module &M) {
  bool Changed = removePassThroughBuiltin(M);
  Changed = removeCompareBuiltin(M) || Changed;
  Changed = sinkMinMax(M) || Changed;
  Changed = removeGEPBuiltins(M) || Changed;
  Changed = insertASpaceCasts(M) || Changed;
  return Changed;
}

// Pass entry point: validate first (checkIR may issue a fatal error),
// then adjust.
bool BPFCheckAndAdjustIR::runOnModule(Module &M) {
  checkIR(M);
  return adjustIR(M);
}