1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file is distributed under the University of Illinois Open Source 6 // License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements the interface to tear out a code region, such as an 11 // individual loop or a parallel section, into a new function, replacing it with 12 // a call to the new function. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Utils/CodeExtractor.h" 17 #include "llvm/Constants.h" 18 #include "llvm/DerivedTypes.h" 19 #include "llvm/Instructions.h" 20 #include "llvm/Intrinsics.h" 21 #include "llvm/LLVMContext.h" 22 #include "llvm/Module.h" 23 #include "llvm/Pass.h" 24 #include "llvm/Analysis/Dominators.h" 25 #include "llvm/Analysis/LoopInfo.h" 26 #include "llvm/Analysis/Verifier.h" 27 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 28 #include "llvm/Support/CommandLine.h" 29 #include "llvm/Support/Debug.h" 30 #include "llvm/Support/ErrorHandling.h" 31 #include "llvm/Support/raw_ostream.h" 32 #include "llvm/ADT/SetVector.h" 33 #include "llvm/ADT/StringExtras.h" 34 #include <algorithm> 35 #include <set> 36 using namespace llvm; 37 38 // Provide a command-line option to aggregate function arguments into a struct 39 // for functions produced by the code extractor. This is useful when converting 40 // extracted functions to pthread-based code, as only one argument (void*) can 41 // be passed in to pthread_create(). 42 static cl::opt<bool> 43 AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, 44 cl::desc("Aggregate arguments to code-extracted functions")); 45 46 /// \brief Test whether a block is valid for extraction. 47 static bool isBlockValidForExtraction(const BasicBlock &BB) { 48 // Landing pads must be in the function where they were inserted for cleanup. 49 if (BB.isLandingPad()) 50 return false; 51 52 // Don't hoist code containing allocas, invokes, or vastarts. 53 for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { 54 if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) 55 return false; 56 if (const CallInst *CI = dyn_cast<CallInst>(I)) 57 if (const Function *F = CI->getCalledFunction()) 58 if (F->getIntrinsicID() == Intrinsic::vastart) 59 return false; 60 } 61 62 return true; 63 } 64 65 /// \brief Build a set of blocks to extract if the input blocks are viable. 66 static SetVector<BasicBlock *> 67 buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs) { 68 SetVector<BasicBlock *> Result; 69 70 assert(!BBs.empty()); 71 72 // Loop over the blocks, adding them to our set-vector, and aborting with an 73 // empty set if we encounter invalid blocks. 74 for (ArrayRef<BasicBlock *>::iterator I = BBs.begin(), E = BBs.end(); 75 I != E; ++I) { 76 if (!Result.insert(*I)) 77 llvm_unreachable("Repeated basic blocks in extraction input"); 78 79 if (!isBlockValidForExtraction(**I)) { 80 Result.clear(); 81 return Result; 82 } 83 } 84 85 #ifndef NDEBUG 86 for (ArrayRef<BasicBlock *>::iterator I = llvm::next(BBs.begin()), 87 E = BBs.end(); 88 I != E; ++I) 89 for (pred_iterator PI = pred_begin(*I), PE = pred_end(*I); 90 PI != PE; ++PI) 91 assert(Result.count(*PI) && 92 "No blocks in this region may have entries from outside the region" 93 " except for the first block!"); 94 #endif 95 96 return Result; 97 } 98 99 CodeExtractor::CodeExtractor(BasicBlock *BB, bool AggregateArgs) 100 : DT(0), AggregateArgs(AggregateArgs||AggregateArgsOpt), 101 Blocks(buildExtractionBlockSet(BB)), NumExitBlocks(~0U) {} 102 103 CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, 104 bool AggregateArgs) 105 : DT(DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), 106 Blocks(buildExtractionBlockSet(BBs)), NumExitBlocks(~0U) {} 107 108 CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs) 109 : DT(&DT), AggregateArgs(AggregateArgs||AggregateArgsOpt), 110 Blocks(buildExtractionBlockSet(L.getBlocks())), NumExitBlocks(~0U) {} 111 112 /// definedInRegion - Return true if the specified value is defined in the 113 /// extracted region. 114 static bool definedInRegion(const SetVector<BasicBlock *> &Blocks, Value *V) { 115 if (Instruction *I = dyn_cast<Instruction>(V)) 116 if (Blocks.count(I->getParent())) 117 return true; 118 return false; 119 } 120 121 /// definedInCaller - Return true if the specified value is defined in the 122 /// function being code extracted, but not in the region being extracted. 123 /// These values must be passed in as live-ins to the function. 124 static bool definedInCaller(const SetVector<BasicBlock *> &Blocks, Value *V) { 125 if (isa<Argument>(V)) return true; 126 if (Instruction *I = dyn_cast<Instruction>(V)) 127 if (!Blocks.count(I->getParent())) 128 return true; 129 return false; 130 } 131 132 void CodeExtractor::findInputsOutputs(ValueSet &Inputs, 133 ValueSet &Outputs) const { 134 for (SetVector<BasicBlock *>::const_iterator I = Blocks.begin(), 135 E = Blocks.end(); 136 I != E; ++I) { 137 BasicBlock *BB = *I; 138 139 // If a used value is defined outside the region, it's an input. If an 140 // instruction is used outside the region, it's an output. 141 for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); 142 II != IE; ++II) { 143 for (User::op_iterator OI = II->op_begin(), OE = II->op_end(); 144 OI != OE; ++OI) 145 if (definedInCaller(Blocks, *OI)) 146 Inputs.insert(*OI); 147 148 for (Value::use_iterator UI = II->use_begin(), UE = II->use_end(); 149 UI != UE; ++UI) 150 if (!definedInRegion(Blocks, *UI)) { 151 Outputs.insert(II); 152 break; 153 } 154 } 155 } 156 } 157 158 /// severSplitPHINodes - If a PHI node has multiple inputs from outside of the 159 /// region, we need to split the entry block of the region so that the PHI node 160 /// is easier to deal with. 161 void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { 162 unsigned NumPredsFromRegion = 0; 163 unsigned NumPredsOutsideRegion = 0; 164 165 if (Header != &Header->getParent()->getEntryBlock()) { 166 PHINode *PN = dyn_cast<PHINode>(Header->begin()); 167 if (!PN) return; // No PHI nodes. 168 169 // If the header node contains any PHI nodes, check to see if there is more 170 // than one entry from outside the region. If so, we need to sever the 171 // header block into two. 172 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 173 if (Blocks.count(PN->getIncomingBlock(i))) 174 ++NumPredsFromRegion; 175 else 176 ++NumPredsOutsideRegion; 177 178 // If there is one (or fewer) predecessor from outside the region, we don't 179 // need to do anything special. 180 if (NumPredsOutsideRegion <= 1) return; 181 } 182 183 // Otherwise, we need to split the header block into two pieces: one 184 // containing PHI nodes merging values from outside of the region, and a 185 // second that contains all of the code for the block and merges back any 186 // incoming values from inside of the region. 187 BasicBlock::iterator AfterPHIs = Header->getFirstNonPHI(); 188 BasicBlock *NewBB = Header->splitBasicBlock(AfterPHIs, 189 Header->getName()+".ce"); 190 191 // We only want to code extract the second block now, and it becomes the new 192 // header of the region. 193 BasicBlock *OldPred = Header; 194 Blocks.remove(OldPred); 195 Blocks.insert(NewBB); 196 Header = NewBB; 197 198 // Okay, update dominator sets. The blocks that dominate the new one are the 199 // blocks that dominate TIBB plus the new block itself. 200 if (DT) 201 DT->splitBlock(NewBB); 202 203 // Okay, now we need to adjust the PHI nodes and any branches from within the 204 // region to go to the new header block instead of the old header block. 205 if (NumPredsFromRegion) { 206 PHINode *PN = cast<PHINode>(OldPred->begin()); 207 // Loop over all of the predecessors of OldPred that are in the region, 208 // changing them to branch to NewBB instead. 209 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 210 if (Blocks.count(PN->getIncomingBlock(i))) { 211 TerminatorInst *TI = PN->getIncomingBlock(i)->getTerminator(); 212 TI->replaceUsesOfWith(OldPred, NewBB); 213 } 214 215 // Okay, everything within the region is now branching to the right block, we 216 // just have to update the PHI nodes now, inserting PHI nodes into NewBB. 217 for (AfterPHIs = OldPred->begin(); isa<PHINode>(AfterPHIs); ++AfterPHIs) { 218 PHINode *PN = cast<PHINode>(AfterPHIs); 219 // Create a new PHI node in the new region, which has an incoming value 220 // from OldPred of PN. 221 PHINode *NewPN = PHINode::Create(PN->getType(), 1 + NumPredsFromRegion, 222 PN->getName()+".ce", NewBB->begin()); 223 NewPN->addIncoming(PN, OldPred); 224 225 // Loop over all of the incoming value in PN, moving them to NewPN if they 226 // are from the extracted region. 227 for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { 228 if (Blocks.count(PN->getIncomingBlock(i))) { 229 NewPN->addIncoming(PN->getIncomingValue(i), PN->getIncomingBlock(i)); 230 PN->removeIncomingValue(i); 231 --i; 232 } 233 } 234 } 235 } 236 } 237 238 void CodeExtractor::splitReturnBlocks() { 239 for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end(); 240 I != E; ++I) 241 if (ReturnInst *RI = dyn_cast<ReturnInst>((*I)->getTerminator())) { 242 BasicBlock *New = (*I)->splitBasicBlock(RI, (*I)->getName()+".ret"); 243 if (DT) { 244 // Old dominates New. New node dominates all other nodes dominated 245 // by Old. 246 DomTreeNode *OldNode = DT->getNode(*I); 247 SmallVector<DomTreeNode*, 8> Children; 248 for (DomTreeNode::iterator DI = OldNode->begin(), DE = OldNode->end(); 249 DI != DE; ++DI) 250 Children.push_back(*DI); 251 252 DomTreeNode *NewNode = DT->addNewBlock(New, *I); 253 254 for (SmallVector<DomTreeNode*, 8>::iterator I = Children.begin(), 255 E = Children.end(); I != E; ++I) 256 DT->changeImmediateDominator(*I, NewNode); 257 } 258 } 259 } 260 261 /// constructFunction - make a function based on inputs and outputs, as follows: 262 /// f(in0, ..., inN, out0, ..., outN) 263 /// 264 Function *CodeExtractor::constructFunction(const ValueSet &inputs, 265 const ValueSet &outputs, 266 BasicBlock *header, 267 BasicBlock *newRootNode, 268 BasicBlock *newHeader, 269 Function *oldFunction, 270 Module *M) { 271 DEBUG(dbgs() << "inputs: " << inputs.size() << "\n"); 272 DEBUG(dbgs() << "outputs: " << outputs.size() << "\n"); 273 274 // This function returns unsigned, outputs will go back by reference. 275 switch (NumExitBlocks) { 276 case 0: 277 case 1: RetTy = Type::getVoidTy(header->getContext()); break; 278 case 2: RetTy = Type::getInt1Ty(header->getContext()); break; 279 default: RetTy = Type::getInt16Ty(header->getContext()); break; 280 } 281 282 std::vector<Type*> paramTy; 283 284 // Add the types of the input values to the function's argument list 285 for (ValueSet::const_iterator i = inputs.begin(), e = inputs.end(); 286 i != e; ++i) { 287 const Value *value = *i; 288 DEBUG(dbgs() << "value used in func: " << *value << "\n"); 289 paramTy.push_back(value->getType()); 290 } 291 292 // Add the types of the output values to the function's argument list. 293 for (ValueSet::const_iterator I = outputs.begin(), E = outputs.end(); 294 I != E; ++I) { 295 DEBUG(dbgs() << "instr used in func: " << **I << "\n"); 296 if (AggregateArgs) 297 paramTy.push_back((*I)->getType()); 298 else 299 paramTy.push_back(PointerType::getUnqual((*I)->getType())); 300 } 301 302 DEBUG(dbgs() << "Function type: " << *RetTy << " f("); 303 for (std::vector<Type*>::iterator i = paramTy.begin(), 304 e = paramTy.end(); i != e; ++i) 305 DEBUG(dbgs() << **i << ", "); 306 DEBUG(dbgs() << ")\n"); 307 308 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 309 PointerType *StructPtr = 310 PointerType::getUnqual(StructType::get(M->getContext(), paramTy)); 311 paramTy.clear(); 312 paramTy.push_back(StructPtr); 313 } 314 FunctionType *funcType = 315 FunctionType::get(RetTy, paramTy, false); 316 317 // Create the new function 318 Function *newFunction = Function::Create(funcType, 319 GlobalValue::InternalLinkage, 320 oldFunction->getName() + "_" + 321 header->getName(), M); 322 // If the old function is no-throw, so is the new one. 323 if (oldFunction->doesNotThrow()) 324 newFunction->setDoesNotThrow(true); 325 326 newFunction->getBasicBlockList().push_back(newRootNode); 327 328 // Create an iterator to name all of the arguments we inserted. 329 Function::arg_iterator AI = newFunction->arg_begin(); 330 331 // Rewrite all users of the inputs in the extracted region to use the 332 // arguments (or appropriate addressing into struct) instead. 333 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 334 Value *RewriteVal; 335 if (AggregateArgs) { 336 Value *Idx[2]; 337 Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext())); 338 Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i); 339 TerminatorInst *TI = newFunction->begin()->getTerminator(); 340 GetElementPtrInst *GEP = 341 GetElementPtrInst::Create(AI, Idx, "gep_" + inputs[i]->getName(), TI); 342 RewriteVal = new LoadInst(GEP, "loadgep_" + inputs[i]->getName(), TI); 343 } else 344 RewriteVal = AI++; 345 346 std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end()); 347 for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); 348 use != useE; ++use) 349 if (Instruction* inst = dyn_cast<Instruction>(*use)) 350 if (Blocks.count(inst->getParent())) 351 inst->replaceUsesOfWith(inputs[i], RewriteVal); 352 } 353 354 // Set names for input and output arguments. 355 if (!AggregateArgs) { 356 AI = newFunction->arg_begin(); 357 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) 358 AI->setName(inputs[i]->getName()); 359 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) 360 AI->setName(outputs[i]->getName()+".out"); 361 } 362 363 // Rewrite branches to basic blocks outside of the loop to new dummy blocks 364 // within the new function. This must be done before we lose track of which 365 // blocks were originally in the code region. 366 std::vector<User*> Users(header->use_begin(), header->use_end()); 367 for (unsigned i = 0, e = Users.size(); i != e; ++i) 368 // The BasicBlock which contains the branch is not in the region 369 // modify the branch target to a new block 370 if (TerminatorInst *TI = dyn_cast<TerminatorInst>(Users[i])) 371 if (!Blocks.count(TI->getParent()) && 372 TI->getParent()->getParent() == oldFunction) 373 TI->replaceUsesOfWith(header, newHeader); 374 375 return newFunction; 376 } 377 378 /// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI 379 /// that uses the value within the basic block, and return the predecessor 380 /// block associated with that use, or return 0 if none is found. 381 static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { 382 for (Value::use_iterator UI = Used->use_begin(), 383 UE = Used->use_end(); UI != UE; ++UI) { 384 PHINode *P = dyn_cast<PHINode>(*UI); 385 if (P && P->getParent() == BB) 386 return P->getIncomingBlock(UI); 387 } 388 389 return 0; 390 } 391 392 /// emitCallAndSwitchStatement - This method sets up the caller side by adding 393 /// the call instruction, splitting any PHI nodes in the header block as 394 /// necessary. 395 void CodeExtractor:: 396 emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, 397 ValueSet &inputs, ValueSet &outputs) { 398 // Emit a call to the new function, passing in: *pointer to struct (if 399 // aggregating parameters), or plan inputs and allocated memory for outputs 400 std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; 401 402 LLVMContext &Context = newFunction->getContext(); 403 404 // Add inputs as params, or to be filled into the struct 405 for (ValueSet::iterator i = inputs.begin(), e = inputs.end(); i != e; ++i) 406 if (AggregateArgs) 407 StructValues.push_back(*i); 408 else 409 params.push_back(*i); 410 411 // Create allocas for the outputs 412 for (ValueSet::iterator i = outputs.begin(), e = outputs.end(); i != e; ++i) { 413 if (AggregateArgs) { 414 StructValues.push_back(*i); 415 } else { 416 AllocaInst *alloca = 417 new AllocaInst((*i)->getType(), 0, (*i)->getName()+".loc", 418 codeReplacer->getParent()->begin()->begin()); 419 ReloadOutputs.push_back(alloca); 420 params.push_back(alloca); 421 } 422 } 423 424 AllocaInst *Struct = 0; 425 if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { 426 std::vector<Type*> ArgTypes; 427 for (ValueSet::iterator v = StructValues.begin(), 428 ve = StructValues.end(); v != ve; ++v) 429 ArgTypes.push_back((*v)->getType()); 430 431 // Allocate a struct at the beginning of this function 432 Type *StructArgTy = StructType::get(newFunction->getContext(), ArgTypes); 433 Struct = 434 new AllocaInst(StructArgTy, 0, "structArg", 435 codeReplacer->getParent()->begin()->begin()); 436 params.push_back(Struct); 437 438 for (unsigned i = 0, e = inputs.size(); i != e; ++i) { 439 Value *Idx[2]; 440 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 441 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i); 442 GetElementPtrInst *GEP = 443 GetElementPtrInst::Create(Struct, Idx, 444 "gep_" + StructValues[i]->getName()); 445 codeReplacer->getInstList().push_back(GEP); 446 StoreInst *SI = new StoreInst(StructValues[i], GEP); 447 codeReplacer->getInstList().push_back(SI); 448 } 449 } 450 451 // Emit the call to the function 452 CallInst *call = CallInst::Create(newFunction, params, 453 NumExitBlocks > 1 ? "targetBlock" : ""); 454 codeReplacer->getInstList().push_back(call); 455 456 Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); 457 unsigned FirstOut = inputs.size(); 458 if (!AggregateArgs) 459 std::advance(OutputArgBegin, inputs.size()); 460 461 // Reload the outputs passed in by reference 462 for (unsigned i = 0, e = outputs.size(); i != e; ++i) { 463 Value *Output = 0; 464 if (AggregateArgs) { 465 Value *Idx[2]; 466 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 467 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); 468 GetElementPtrInst *GEP 469 = GetElementPtrInst::Create(Struct, Idx, 470 "gep_reload_" + outputs[i]->getName()); 471 codeReplacer->getInstList().push_back(GEP); 472 Output = GEP; 473 } else { 474 Output = ReloadOutputs[i]; 475 } 476 LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); 477 Reloads.push_back(load); 478 codeReplacer->getInstList().push_back(load); 479 std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end()); 480 for (unsigned u = 0, e = Users.size(); u != e; ++u) { 481 Instruction *inst = cast<Instruction>(Users[u]); 482 if (!Blocks.count(inst->getParent())) 483 inst->replaceUsesOfWith(outputs[i], load); 484 } 485 } 486 487 // Now we can emit a switch statement using the call as a value. 488 SwitchInst *TheSwitch = 489 SwitchInst::Create(Constant::getNullValue(Type::getInt16Ty(Context)), 490 codeReplacer, 0, codeReplacer); 491 492 // Since there may be multiple exits from the original region, make the new 493 // function return an unsigned, switch on that number. This loop iterates 494 // over all of the blocks in the extracted region, updating any terminator 495 // instructions in the to-be-extracted region that branch to blocks that are 496 // not in the region to be extracted. 497 std::map<BasicBlock*, BasicBlock*> ExitBlockMap; 498 499 unsigned switchVal = 0; 500 for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), 501 e = Blocks.end(); i != e; ++i) { 502 TerminatorInst *TI = (*i)->getTerminator(); 503 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 504 if (!Blocks.count(TI->getSuccessor(i))) { 505 BasicBlock *OldTarget = TI->getSuccessor(i); 506 // add a new basic block which returns the appropriate value 507 BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; 508 if (!NewTarget) { 509 // If we don't already have an exit stub for this non-extracted 510 // destination, create one now! 511 NewTarget = BasicBlock::Create(Context, 512 OldTarget->getName() + ".exitStub", 513 newFunction); 514 unsigned SuccNum = switchVal++; 515 516 Value *brVal = 0; 517 switch (NumExitBlocks) { 518 case 0: 519 case 1: break; // No value needed. 520 case 2: // Conditional branch, return a bool 521 brVal = ConstantInt::get(Type::getInt1Ty(Context), !SuccNum); 522 break; 523 default: 524 brVal = ConstantInt::get(Type::getInt16Ty(Context), SuccNum); 525 break; 526 } 527 528 ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); 529 530 // Update the switch instruction. 531 TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), 532 SuccNum), 533 OldTarget); 534 535 // Restore values just before we exit 536 Function::arg_iterator OAI = OutputArgBegin; 537 for (unsigned out = 0, e = outputs.size(); out != e; ++out) { 538 // For an invoke, the normal destination is the only one that is 539 // dominated by the result of the invocation 540 BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent(); 541 542 bool DominatesDef = true; 543 544 if (InvokeInst *Invoke = dyn_cast<InvokeInst>(outputs[out])) { 545 DefBlock = Invoke->getNormalDest(); 546 547 // Make sure we are looking at the original successor block, not 548 // at a newly inserted exit block, which won't be in the dominator 549 // info. 550 for (std::map<BasicBlock*, BasicBlock*>::iterator I = 551 ExitBlockMap.begin(), E = ExitBlockMap.end(); I != E; ++I) 552 if (DefBlock == I->second) { 553 DefBlock = I->first; 554 break; 555 } 556 557 // In the extract block case, if the block we are extracting ends 558 // with an invoke instruction, make sure that we don't emit a 559 // store of the invoke value for the unwind block. 560 if (!DT && DefBlock != OldTarget) 561 DominatesDef = false; 562 } 563 564 if (DT) { 565 DominatesDef = DT->dominates(DefBlock, OldTarget); 566 567 // If the output value is used by a phi in the target block, 568 // then we need to test for dominance of the phi's predecessor 569 // instead. Unfortunately, this a little complicated since we 570 // have already rewritten uses of the value to uses of the reload. 571 BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], 572 OldTarget); 573 if (pred && DT && DT->dominates(DefBlock, pred)) 574 DominatesDef = true; 575 } 576 577 if (DominatesDef) { 578 if (AggregateArgs) { 579 Value *Idx[2]; 580 Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); 581 Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), 582 FirstOut+out); 583 GetElementPtrInst *GEP = 584 GetElementPtrInst::Create(OAI, Idx, 585 "gep_" + outputs[out]->getName(), 586 NTRet); 587 new StoreInst(outputs[out], GEP, NTRet); 588 } else { 589 new StoreInst(outputs[out], OAI, NTRet); 590 } 591 } 592 // Advance output iterator even if we don't emit a store 593 if (!AggregateArgs) ++OAI; 594 } 595 } 596 597 // rewrite the original branch instruction with this new target 598 TI->setSuccessor(i, NewTarget); 599 } 600 } 601 602 // Now that we've done the deed, simplify the switch instruction. 603 Type *OldFnRetTy = TheSwitch->getParent()->getParent()->getReturnType(); 604 switch (NumExitBlocks) { 605 case 0: 606 // There are no successors (the block containing the switch itself), which 607 // means that previously this was the last part of the function, and hence 608 // this should be rewritten as a `ret' 609 610 // Check if the function should return a value 611 if (OldFnRetTy->isVoidTy()) { 612 ReturnInst::Create(Context, 0, TheSwitch); // Return void 613 } else if (OldFnRetTy == TheSwitch->getCondition()->getType()) { 614 // return what we have 615 ReturnInst::Create(Context, TheSwitch->getCondition(), TheSwitch); 616 } else { 617 // Otherwise we must have code extracted an unwind or something, just 618 // return whatever we want. 619 ReturnInst::Create(Context, 620 Constant::getNullValue(OldFnRetTy), TheSwitch); 621 } 622 623 TheSwitch->eraseFromParent(); 624 break; 625 case 1: 626 // Only a single destination, change the switch into an unconditional 627 // branch. 628 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch); 629 TheSwitch->eraseFromParent(); 630 break; 631 case 2: 632 BranchInst::Create(TheSwitch->getSuccessor(1), TheSwitch->getSuccessor(2), 633 call, TheSwitch); 634 TheSwitch->eraseFromParent(); 635 break; 636 default: 637 // Otherwise, make the default destination of the switch instruction be one 638 // of the other successors. 639 TheSwitch->setCondition(call); 640 TheSwitch->setDefaultDest(TheSwitch->getSuccessor(NumExitBlocks)); 641 // Remove redundant case 642 TheSwitch->removeCase(SwitchInst::CaseIt(TheSwitch, NumExitBlocks-1)); 643 break; 644 } 645 } 646 647 void CodeExtractor::moveCodeToFunction(Function *newFunction) { 648 Function *oldFunc = (*Blocks.begin())->getParent(); 649 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); 650 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); 651 652 for (SetVector<BasicBlock*>::const_iterator i = Blocks.begin(), 653 e = Blocks.end(); i != e; ++i) { 654 // Delete the basic block from the old function, and the list of blocks 655 oldBlocks.remove(*i); 656 657 // Insert this basic block into the new function 658 newBlocks.push_back(*i); 659 } 660 } 661 662 Function *CodeExtractor::extractCodeRegion() { 663 if (!isEligible()) 664 return 0; 665 666 ValueSet inputs, outputs; 667 668 // Assumption: this is a single-entry code region, and the header is the first 669 // block in the region. 670 BasicBlock *header = *Blocks.begin(); 671 672 // If we have to split PHI nodes or the entry block, do so now. 673 severSplitPHINodes(header); 674 675 // If we have any return instructions in the region, split those blocks so 676 // that the return is not in the region. 677 splitReturnBlocks(); 678 679 Function *oldFunction = header->getParent(); 680 681 // This takes place of the original loop 682 BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), 683 "codeRepl", oldFunction, 684 header); 685 686 // The new function needs a root node because other nodes can branch to the 687 // head of the region, but the entry node of a function cannot have preds. 688 BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), 689 "newFuncRoot"); 690 newFuncRoot->getInstList().push_back(BranchInst::Create(header)); 691 692 // Find inputs to, outputs from the code region. 693 findInputsOutputs(inputs, outputs); 694 695 SmallPtrSet<BasicBlock *, 1> ExitBlocks; 696 for (SetVector<BasicBlock *>::iterator I = Blocks.begin(), E = Blocks.end(); 697 I != E; ++I) 698 for (succ_iterator SI = succ_begin(*I), SE = succ_end(*I); SI != SE; ++SI) 699 if (!Blocks.count(*SI)) 700 ExitBlocks.insert(*SI); 701 NumExitBlocks = ExitBlocks.size(); 702 703 // Construct new function based on inputs/outputs & add allocas for all defs. 704 Function *newFunction = constructFunction(inputs, outputs, header, 705 newFuncRoot, 706 codeReplacer, oldFunction, 707 oldFunction->getParent()); 708 709 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); 710 711 moveCodeToFunction(newFunction); 712 713 // Loop over all of the PHI nodes in the header block, and change any 714 // references to the old incoming edge to be the new incoming edge. 715 for (BasicBlock::iterator I = header->begin(); isa<PHINode>(I); ++I) { 716 PHINode *PN = cast<PHINode>(I); 717 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 718 if (!Blocks.count(PN->getIncomingBlock(i))) 719 PN->setIncomingBlock(i, newFuncRoot); 720 } 721 722 // Look at all successors of the codeReplacer block. If any of these blocks 723 // had PHI nodes in them, we need to update the "from" block to be the code 724 // replacer, not the original block in the extracted region. 725 std::vector<BasicBlock*> Succs(succ_begin(codeReplacer), 726 succ_end(codeReplacer)); 727 for (unsigned i = 0, e = Succs.size(); i != e; ++i) 728 for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { 729 PHINode *PN = cast<PHINode>(I); 730 std::set<BasicBlock*> ProcessedPreds; 731 for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) 732 if (Blocks.count(PN->getIncomingBlock(i))) { 733 if (ProcessedPreds.insert(PN->getIncomingBlock(i)).second) 734 PN->setIncomingBlock(i, codeReplacer); 735 else { 736 // There were multiple entries in the PHI for this block, now there 737 // is only one, so remove the duplicated entries. 738 PN->removeIncomingValue(i, false); 739 --i; --e; 740 } 741 } 742 } 743 744 //cerr << "NEW FUNCTION: " << *newFunction; 745 // verifyFunction(*newFunction); 746 747 // cerr << "OLD FUNCTION: " << *oldFunction; 748 // verifyFunction(*oldFunction); 749 750 DEBUG(if (verifyFunction(*newFunction)) 751 report_fatal_error("verifyFunction failed!")); 752 return newFunction; 753 } 754