1 //===- CallSiteSplitting.cpp ----------------------------------------------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements a transformation that tries to split a call-site to pass 11 // more constrained arguments if its argument is predicated in the control flow 12 // so that we can expose better context to the later passes (e.g, inliner, jump 13 // threading, or IPA-CP based function cloning, etc.). 14 // As of now we support two cases : 15 // 16 // 1) If a call site is dominated by an OR condition and if any of its arguments 17 // are predicated on this OR condition, try to split the condition with more 18 // constrained arguments. For example, in the code below, we try to split the 19 // call site since we can predicate the argument(ptr) based on the OR condition. 20 // 21 // Split from : 22 // if (!ptr || c) 23 // callee(ptr); 24 // to : 25 // if (!ptr) 26 // callee(null) // set the known constant value 27 // else if (c) 28 // callee(nonnull ptr) // set non-null attribute in the argument 29 // 30 // 2) We can also split a call-site based on constant incoming values of a PHI 31 // For example, 32 // from : 33 // Header: 34 // %c = icmp eq i32 %i1, %i2 35 // br i1 %c, label %Tail, label %TBB 36 // TBB: 37 // br label Tail% 38 // Tail: 39 // %p = phi i32 [ 0, %Header], [ 1, %TBB] 40 // call void @bar(i32 %p) 41 // to 42 // Header: 43 // %c = icmp eq i32 %i1, %i2 44 // br i1 %c, label %Tail-split0, label %TBB 45 // TBB: 46 // br label %Tail-split1 47 // Tail-split0: 48 // call void @bar(i32 0) 49 // br label %Tail 50 // Tail-split1: 51 // call void @bar(i32 1) 52 // br label %Tail 53 // Tail: 54 // %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ] 55 // 56 //===----------------------------------------------------------------------===// 57 58 #include "llvm/Transforms/Scalar/CallSiteSplitting.h" 59 #include "llvm/ADT/Statistic.h" 60 #include "llvm/Analysis/TargetLibraryInfo.h" 61 #include "llvm/IR/IntrinsicInst.h" 62 #include "llvm/IR/PatternMatch.h" 63 #include "llvm/Support/Debug.h" 64 #include "llvm/Transforms/Scalar.h" 65 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 66 #include "llvm/Transforms/Utils/Local.h" 67 68 using namespace llvm; 69 using namespace PatternMatch; 70 71 #define DEBUG_TYPE "callsite-splitting" 72 73 STATISTIC(NumCallSiteSplit, "Number of call-site split"); 74 75 static void addNonNullAttribute(Instruction *CallI, Instruction *&NewCallI, 76 Value *Op) { 77 if (!NewCallI) 78 NewCallI = CallI->clone(); 79 CallSite CS(NewCallI); 80 unsigned ArgNo = 0; 81 for (auto &I : CS.args()) { 82 if (&*I == Op) 83 CS.addParamAttr(ArgNo, Attribute::NonNull); 84 ++ArgNo; 85 } 86 } 87 88 static void setConstantInArgument(Instruction *CallI, Instruction *&NewCallI, 89 Value *Op, Constant *ConstValue) { 90 if (!NewCallI) 91 NewCallI = CallI->clone(); 92 CallSite CS(NewCallI); 93 unsigned ArgNo = 0; 94 for (auto &I : CS.args()) { 95 if (&*I == Op) 96 CS.setArgument(ArgNo, ConstValue); 97 ++ArgNo; 98 } 99 } 100 101 static bool createCallSitesOnOrPredicatedArgument( 102 CallSite CS, Instruction *&NewCSTakenFromHeader, 103 Instruction *&NewCSTakenFromNextCond, 104 SmallVectorImpl<BranchInst *> &BranchInsts, BasicBlock *HeaderBB) { 105 assert(BranchInsts.size() <= 2 && 106 "Unexpected number of blocks in the OR predicated condition"); 107 Instruction *Instr = CS.getInstruction(); 108 BasicBlock *CallSiteBB = Instr->getParent(); 109 TerminatorInst *HeaderTI = HeaderBB->getTerminator(); 110 bool IsCSInTakenPath = CallSiteBB == HeaderTI->getSuccessor(0); 111 112 for (unsigned I = 0, E = BranchInsts.size(); I != E; ++I) { 113 BranchInst *PBI = BranchInsts[I]; 114 assert(isa<ICmpInst>(PBI->getCondition()) && 115 "Unexpected condition in a conditional branch."); 116 ICmpInst *Cmp = cast<ICmpInst>(PBI->getCondition()); 117 Value *Arg = Cmp->getOperand(0); 118 assert(isa<Constant>(Cmp->getOperand(1)) && 119 "Expected op1 to be a constant."); 120 Constant *ConstVal = cast<Constant>(Cmp->getOperand(1)); 121 CmpInst::Predicate Pred = Cmp->getPredicate(); 122 123 if (PBI->getParent() == HeaderBB) { 124 Instruction *&CallTakenFromHeader = 125 IsCSInTakenPath ? NewCSTakenFromHeader : NewCSTakenFromNextCond; 126 Instruction *&CallUntakenFromHeader = 127 IsCSInTakenPath ? NewCSTakenFromNextCond : NewCSTakenFromHeader; 128 129 assert(Pred == ICmpInst::ICMP_EQ || 130 Pred == ICmpInst::ICMP_NE && 131 "Unexpected predicate in an OR condition"); 132 133 // Set the constant value for agruments in the call predicated based on 134 // the OR condition. 135 Instruction *&CallToSetConst = Pred == ICmpInst::ICMP_EQ 136 ? CallTakenFromHeader 137 : CallUntakenFromHeader; 138 setConstantInArgument(Instr, CallToSetConst, Arg, ConstVal); 139 140 // Add the NonNull attribute if compared with the null pointer. 141 if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { 142 Instruction *&CallToSetAttr = Pred == ICmpInst::ICMP_EQ 143 ? CallUntakenFromHeader 144 : CallTakenFromHeader; 145 addNonNullAttribute(Instr, CallToSetAttr, Arg); 146 } 147 continue; 148 } 149 150 if (Pred == ICmpInst::ICMP_EQ) { 151 if (PBI->getSuccessor(0) == Instr->getParent()) { 152 // Set the constant value for the call taken from the second block in 153 // the OR condition. 154 setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal); 155 } else { 156 // Add the NonNull attribute if compared with the null pointer for the 157 // call taken from the second block in the OR condition. 158 if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) 159 addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg); 160 } 161 } else { 162 if (PBI->getSuccessor(0) == Instr->getParent()) { 163 // Add the NonNull attribute if compared with the null pointer for the 164 // call taken from the second block in the OR condition. 165 if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) 166 addNonNullAttribute(Instr, NewCSTakenFromNextCond, Arg); 167 } else if (Pred == ICmpInst::ICMP_NE) { 168 // Set the constant value for the call in the untaken path from the 169 // header block. 170 setConstantInArgument(Instr, NewCSTakenFromNextCond, Arg, ConstVal); 171 } else 172 llvm_unreachable("Unexpected condition"); 173 } 174 } 175 return NewCSTakenFromHeader || NewCSTakenFromNextCond; 176 } 177 178 static bool canSplitCallSite(CallSite CS) { 179 // FIXME: As of now we handle only CallInst. InvokeInst could be handled 180 // without too much effort. 181 Instruction *Instr = CS.getInstruction(); 182 if (!isa<CallInst>(Instr)) 183 return false; 184 185 // Allow splitting a call-site only when there is no instruction before the 186 // call-site in the basic block. Based on this constraint, we only clone the 187 // call instruction, and we do not move a call-site across any other 188 // instruction. 189 BasicBlock *CallSiteBB = Instr->getParent(); 190 if (Instr != CallSiteBB->getFirstNonPHI()) 191 return false; 192 193 pred_iterator PII = pred_begin(CallSiteBB); 194 pred_iterator PIE = pred_end(CallSiteBB); 195 unsigned NumPreds = std::distance(PII, PIE); 196 197 // Allow only one extra call-site. No more than two from one call-site. 198 if (NumPreds != 2) 199 return false; 200 201 // Cannot split an edge from an IndirectBrInst. 202 BasicBlock *Preds[2] = {*PII++, *PII}; 203 if (isa<IndirectBrInst>(Preds[0]->getTerminator()) || 204 isa<IndirectBrInst>(Preds[1]->getTerminator())) 205 return false; 206 207 return CallSiteBB->canSplitPredecessors(); 208 } 209 210 /// Return true if the CS is split into its new predecessors which are directly 211 /// hooked to each of its orignial predecessors pointed by PredBB1 and PredBB2. 212 /// Note that PredBB1 and PredBB2 are decided in findPredicatedArgument(), 213 /// especially for the OR predicated case where PredBB1 will point the header, 214 /// and PredBB2 will point the the second compare block. CallInst1 and CallInst2 215 /// will be the new call-sites placed in the new predecessors split for PredBB1 216 /// and PredBB2, repectively. Therefore, CallInst1 will be the call-site placed 217 /// between Header and Tail, and CallInst2 will be the call-site between TBB and 218 /// Tail. For example, in the IR below with an OR condition, the call-site can 219 /// be split 220 /// 221 /// from : 222 /// 223 /// Header: 224 /// %c = icmp eq i32* %a, null 225 /// br i1 %c %Tail, %TBB 226 /// TBB: 227 /// %c2 = icmp eq i32* %b, null 228 /// br i1 %c %Tail, %End 229 /// Tail: 230 /// %ca = call i1 @callee (i32* %a, i32* %b) 231 /// 232 /// to : 233 /// 234 /// Header: // PredBB1 is Header 235 /// %c = icmp eq i32* %a, null 236 /// br i1 %c %Tail-split1, %TBB 237 /// TBB: // PredBB2 is TBB 238 /// %c2 = icmp eq i32* %b, null 239 /// br i1 %c %Tail-split2, %End 240 /// Tail-split1: 241 /// %ca1 = call @callee (i32* null, i32* %b) // CallInst1 242 /// br %Tail 243 /// Tail-split2: 244 /// %ca2 = call @callee (i32* nonnull %a, i32* null) // CallInst2 245 /// br %Tail 246 /// Tail: 247 /// %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2] 248 /// 249 /// Note that for an OR predicated case, CallInst1 and CallInst2 should be 250 /// created with more constrained arguments in 251 /// createCallSitesOnOrPredicatedArgument(). 252 static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2, 253 Instruction *CallInst1, Instruction *CallInst2) { 254 Instruction *Instr = CS.getInstruction(); 255 BasicBlock *TailBB = Instr->getParent(); 256 assert(Instr == (TailBB->getFirstNonPHI()) && "Unexpected call-site"); 257 258 BasicBlock *SplitBlock1 = 259 SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split"); 260 BasicBlock *SplitBlock2 = 261 SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split"); 262 263 assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split."); 264 265 if (!CallInst1) 266 CallInst1 = Instr->clone(); 267 if (!CallInst2) 268 CallInst2 = Instr->clone(); 269 270 CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt()); 271 CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt()); 272 273 CallSite CS1(CallInst1); 274 CallSite CS2(CallInst2); 275 276 // Handle PHIs used as arguments in the call-site. 277 for (auto &PI : *TailBB) { 278 PHINode *PN = dyn_cast<PHINode>(&PI); 279 if (!PN) 280 break; 281 unsigned ArgNo = 0; 282 for (auto &CI : CS.args()) { 283 if (&*CI == PN) { 284 CS1.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock1)); 285 CS2.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock2)); 286 } 287 ++ArgNo; 288 } 289 } 290 291 // Replace users of the original call with a PHI mering call-sites split. 292 if (Instr->getNumUses()) { 293 PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call", Instr); 294 PN->addIncoming(CallInst1, SplitBlock1); 295 PN->addIncoming(CallInst2, SplitBlock2); 296 Instr->replaceAllUsesWith(PN); 297 } 298 DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); 299 DEBUG(dbgs() << " " << *CallInst1 << " in " << SplitBlock1->getName() 300 << "\n"); 301 DEBUG(dbgs() << " " << *CallInst2 << " in " << SplitBlock2->getName() 302 << "\n"); 303 Instr->eraseFromParent(); 304 NumCallSiteSplit++; 305 } 306 307 static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) { 308 assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand."); 309 Value *Op0 = Cmp->getOperand(0); 310 unsigned ArgNo = 0; 311 for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; 312 ++I, ++ArgNo) { 313 // Don't consider constant or arguments that are already known non-null. 314 if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull)) 315 continue; 316 317 if (*I == Op0) 318 return true; 319 } 320 return false; 321 } 322 323 static void findOrCondRelevantToCallArgument( 324 CallSite CS, BasicBlock *PredBB, BasicBlock *OtherPredBB, 325 SmallVectorImpl<BranchInst *> &BranchInsts, BasicBlock *&HeaderBB) { 326 auto *PBI = dyn_cast<BranchInst>(PredBB->getTerminator()); 327 if (!PBI || !PBI->isConditional()) 328 return; 329 330 if (PBI->getSuccessor(0) == OtherPredBB || 331 PBI->getSuccessor(1) == OtherPredBB) 332 if (PredBB == OtherPredBB->getSinglePredecessor()) { 333 assert(!HeaderBB && "Expect to find only a single header block"); 334 HeaderBB = PredBB; 335 } 336 337 CmpInst::Predicate Pred; 338 Value *Cond = PBI->getCondition(); 339 if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) 340 return; 341 ICmpInst *Cmp = cast<ICmpInst>(Cond); 342 if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) 343 if (isCondRelevantToAnyCallArgument(Cmp, CS)) 344 BranchInsts.push_back(PBI); 345 } 346 347 // Return true if the call-site has an argument which is a PHI with only 348 // constant incoming values. 349 static bool isPredicatedOnPHI(CallSite CS) { 350 Instruction *Instr = CS.getInstruction(); 351 BasicBlock *Parent = Instr->getParent(); 352 if (Instr != Parent->getFirstNonPHI()) 353 return false; 354 355 for (auto &BI : *Parent) { 356 if (PHINode *PN = dyn_cast<PHINode>(&BI)) { 357 for (auto &I : CS.args()) 358 if (&*I == PN) { 359 assert(PN->getNumIncomingValues() == 2 && 360 "Unexpected number of incoming values"); 361 if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1)) 362 return false; 363 if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) 364 continue; 365 if (isa<Constant>(PN->getIncomingValue(0)) && 366 isa<Constant>(PN->getIncomingValue(1))) 367 return true; 368 } 369 } 370 break; 371 } 372 return false; 373 } 374 375 // Return true if an agument in CS is predicated on an 'or' condition. 376 // Create new call-site with arguments constrained based on the OR condition. 377 static bool findPredicatedOnOrCondition(CallSite CS, BasicBlock *PredBB1, 378 BasicBlock *PredBB2, 379 Instruction *&NewCallTakenFromHeader, 380 Instruction *&NewCallTakenFromNextCond, 381 BasicBlock *&HeaderBB) { 382 SmallVector<BranchInst *, 4> BranchInsts; 383 findOrCondRelevantToCallArgument(CS, PredBB1, PredBB2, BranchInsts, HeaderBB); 384 findOrCondRelevantToCallArgument(CS, PredBB2, PredBB1, BranchInsts, HeaderBB); 385 if (BranchInsts.empty() || !HeaderBB) 386 return false; 387 388 // If an OR condition is detected, try to create call sites with constrained 389 // arguments (e.g., NonNull attribute or constant value). 390 return createCallSitesOnOrPredicatedArgument(CS, NewCallTakenFromHeader, 391 NewCallTakenFromNextCond, 392 BranchInsts, HeaderBB); 393 } 394 395 static bool findPredicatedArgument(CallSite CS, Instruction *&CallInst1, 396 Instruction *&CallInst2, 397 BasicBlock *&PredBB1, BasicBlock *&PredBB2) { 398 BasicBlock *CallSiteBB = CS.getInstruction()->getParent(); 399 pred_iterator PII = pred_begin(CallSiteBB); 400 pred_iterator PIE = pred_end(CallSiteBB); 401 assert(std::distance(PII, PIE) == 2 && "Expect only two predecessors."); 402 BasicBlock *Preds[2] = {*PII++, *PII}; 403 BasicBlock *&HeaderBB = PredBB1; 404 if (!findPredicatedOnOrCondition(CS, Preds[0], Preds[1], CallInst1, CallInst2, 405 HeaderBB) && 406 !isPredicatedOnPHI(CS)) 407 return false; 408 409 if (!PredBB1) 410 PredBB1 = Preds[0]; 411 412 PredBB2 = PredBB1 == Preds[0] ? Preds[1] : Preds[0]; 413 return true; 414 } 415 416 static bool tryToSplitCallSite(CallSite CS) { 417 if (!CS.arg_size()) 418 return false; 419 420 BasicBlock *PredBB1 = nullptr; 421 BasicBlock *PredBB2 = nullptr; 422 Instruction *CallInst1 = nullptr; 423 Instruction *CallInst2 = nullptr; 424 if (!canSplitCallSite(CS) || 425 !findPredicatedArgument(CS, CallInst1, CallInst2, PredBB1, PredBB2)) { 426 assert(!CallInst1 && !CallInst2 && "Unexpected new call-sites cloned."); 427 return false; 428 } 429 splitCallSite(CS, PredBB1, PredBB2, CallInst1, CallInst2); 430 return true; 431 } 432 433 static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) { 434 bool Changed = false; 435 for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { 436 BasicBlock &BB = *BI++; 437 for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { 438 Instruction *I = &*II++; 439 CallSite CS(cast<Value>(I)); 440 if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI)) 441 continue; 442 443 Function *Callee = CS.getCalledFunction(); 444 if (!Callee || Callee->isDeclaration()) 445 continue; 446 Changed |= tryToSplitCallSite(CS); 447 } 448 } 449 return Changed; 450 } 451 452 namespace { 453 struct CallSiteSplittingLegacyPass : public FunctionPass { 454 static char ID; 455 CallSiteSplittingLegacyPass() : FunctionPass(ID) { 456 initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); 457 } 458 459 void getAnalysisUsage(AnalysisUsage &AU) const override { 460 AU.addRequired<TargetLibraryInfoWrapperPass>(); 461 FunctionPass::getAnalysisUsage(AU); 462 } 463 464 bool runOnFunction(Function &F) override { 465 if (skipFunction(F)) 466 return false; 467 468 auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 469 return doCallSiteSplitting(F, TLI); 470 } 471 }; 472 } // namespace 473 474 char CallSiteSplittingLegacyPass::ID = 0; 475 INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", 476 "Call-site splitting", false, false) 477 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 478 INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", 479 "Call-site splitting", false, false) 480 FunctionPass *llvm::createCallSiteSplittingPass() { 481 return new CallSiteSplittingLegacyPass(); 482 } 483 484 PreservedAnalyses CallSiteSplittingPass::run(Function &F, 485 FunctionAnalysisManager &AM) { 486 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); 487 488 if (!doCallSiteSplitting(F, TLI)) 489 return PreservedAnalyses::all(); 490 PreservedAnalyses PA; 491 return PA; 492 } 493