1 //===- CodeExtractor.cpp - Pull code region into a new function -----------===// 2 // 3 // The LLVM Compiler Infrastructure 4 // 5 // This file was developed by the LLVM research group and is distributed under 6 // the University of Illinois Open Source License. See LICENSE.TXT for details. 7 // 8 //===----------------------------------------------------------------------===// 9 // 10 // This file implements the interface to tear out a code region, such as an 11 // individual loop or a parallel section, into a new function, replacing it with 12 // a call to the new function. 13 // 14 //===----------------------------------------------------------------------===// 15 16 #include "llvm/Transforms/Utils/FunctionUtils.h" 17 #include "llvm/Constants.h" 18 #include "llvm/DerivedTypes.h" 19 #include "llvm/Instructions.h" 20 #include "llvm/Module.h" 21 #include "llvm/Pass.h" 22 #include "llvm/Analysis/Dominators.h" 23 #include "llvm/Analysis/LoopInfo.h" 24 #include "llvm/Analysis/Verifier.h" 25 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 26 #include "Support/Debug.h" 27 #include "Support/StringExtras.h" 28 #include <algorithm> 29 #include <set> 30 using namespace llvm; 31 32 namespace { 33 34 /// getFunctionArg - Return a pointer to F's ARGNOth argument. 35 /// 36 Argument *getFunctionArg(Function *F, unsigned argno) { 37 Function::aiterator I = F->abegin(); 38 std::advance(I, argno); 39 return I; 40 } 41 42 class CodeExtractor { 43 typedef std::vector<Value*> Values; 44 typedef std::vector<std::pair<unsigned, unsigned> > PhiValChangesTy; 45 typedef std::map<PHINode*, PhiValChangesTy> PhiVal2ArgTy; 46 PhiVal2ArgTy PhiVal2Arg; 47 std::set<BasicBlock*> BlocksToExtract; 48 DominatorSet *DS; 49 public: 50 CodeExtractor(DominatorSet *ds = 0) : DS(ds) {} 51 52 Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code); 53 54 private: 55 void findInputsOutputs(Values &inputs, Values &outputs, 56 BasicBlock *newHeader, 57 BasicBlock *newRootNode); 58 59 void processPhiNodeInputs(PHINode *Phi, 60 Values &inputs, 61 BasicBlock *newHeader, 62 BasicBlock *newRootNode); 63 64 void rewritePhiNodes(Function *F, BasicBlock *newFuncRoot); 65 66 Function *constructFunction(const Values &inputs, 67 const Values &outputs, 68 BasicBlock *newRootNode, BasicBlock *newHeader, 69 Function *oldFunction, Module *M); 70 71 void moveCodeToFunction(Function *newFunction); 72 73 void emitCallAndSwitchStatement(Function *newFunction, 74 BasicBlock *newHeader, 75 Values &inputs, 76 Values &outputs); 77 78 }; 79 } 80 81 void CodeExtractor::processPhiNodeInputs(PHINode *Phi, 82 Values &inputs, 83 BasicBlock *codeReplacer, 84 BasicBlock *newFuncRoot) { 85 // Separate incoming values and BasicBlocks as internal/external. We ignore 86 // the case where both the value and BasicBlock are internal, because we don't 87 // need to do a thing. 88 std::vector<unsigned> EValEBB; 89 std::vector<unsigned> EValIBB; 90 std::vector<unsigned> IValEBB; 91 92 for (unsigned i = 0, e = Phi->getNumIncomingValues(); i != e; ++i) { 93 Value *phiVal = Phi->getIncomingValue(i); 94 if (Instruction *Inst = dyn_cast<Instruction>(phiVal)) { 95 if (BlocksToExtract.count(Inst->getParent())) { 96 if (!BlocksToExtract.count(Phi->getIncomingBlock(i))) 97 IValEBB.push_back(i); 98 } else { 99 if (BlocksToExtract.count(Phi->getIncomingBlock(i))) 100 EValIBB.push_back(i); 101 else 102 EValEBB.push_back(i); 103 } 104 } else if (Argument *Arg = dyn_cast<Argument>(phiVal)) { 105 // arguments are external 106 if (BlocksToExtract.count(Phi->getIncomingBlock(i))) 107 EValIBB.push_back(i); 108 else 109 EValEBB.push_back(i); 110 } else { 111 // Globals/Constants are internal, but considered `external' if they are 112 // coming from an external block. 113 if (!BlocksToExtract.count(Phi->getIncomingBlock(i))) 114 EValEBB.push_back(i); 115 } 116 } 117 118 // Both value and block are external. Need to group all of these, have an 119 // external phi, pass the result as an argument, and have THIS phi use that 120 // result. 121 if (EValEBB.size() > 0) { 122 if (EValEBB.size() == 1) { 123 // Now if it's coming from the newFuncRoot, it's that funky input 124 unsigned phiIdx = EValEBB[0]; 125 if (!isa<Constant>(Phi->getIncomingValue(phiIdx))) { 126 PhiVal2Arg[Phi].push_back(std::make_pair(phiIdx, inputs.size())); 127 // We can just pass this value in as argument 128 inputs.push_back(Phi->getIncomingValue(phiIdx)); 129 } 130 Phi->setIncomingBlock(phiIdx, newFuncRoot); 131 } else { 132 PHINode *externalPhi = new PHINode(Phi->getType(), "extPhi"); 133 codeReplacer->getInstList().insert(codeReplacer->begin(), externalPhi); 134 for (std::vector<unsigned>::iterator i = EValEBB.begin(), 135 e = EValEBB.end(); i != e; ++i) { 136 externalPhi->addIncoming(Phi->getIncomingValue(*i), 137 Phi->getIncomingBlock(*i)); 138 139 // We make these values invalid instead of deleting them because that 140 // would shift the indices of other values... The fixPhiNodes should 141 // clean these phi nodes up later. 142 Phi->setIncomingValue(*i, 0); 143 Phi->setIncomingBlock(*i, 0); 144 } 145 PhiVal2Arg[Phi].push_back(std::make_pair(Phi->getNumIncomingValues(), 146 inputs.size())); 147 // We can just pass this value in as argument 148 inputs.push_back(externalPhi); 149 } 150 } 151 152 // When the value is external, but block internal... just pass it in as 153 // argument, no change to phi node 154 for (std::vector<unsigned>::iterator i = EValIBB.begin(), 155 e = EValIBB.end(); i != e; ++i) { 156 // rewrite the phi input node to be an argument 157 PhiVal2Arg[Phi].push_back(std::make_pair(*i, inputs.size())); 158 inputs.push_back(Phi->getIncomingValue(*i)); 159 } 160 161 // Value internal, block external this can happen if we are extracting a part 162 // of a loop. 163 for (std::vector<unsigned>::iterator i = IValEBB.begin(), 164 e = IValEBB.end(); i != e; ++i) { 165 assert(0 && "Cannot (YET) handle internal values via external blocks"); 166 } 167 } 168 169 170 void CodeExtractor::findInputsOutputs(Values &inputs, Values &outputs, 171 BasicBlock *newHeader, 172 BasicBlock *newRootNode) { 173 for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(), 174 ce = BlocksToExtract.end(); ci != ce; ++ci) { 175 BasicBlock *BB = *ci; 176 for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { 177 // If a used value is defined outside the region, it's an input. If an 178 // instruction is used outside the region, it's an output. 179 if (PHINode *Phi = dyn_cast<PHINode>(I)) { 180 processPhiNodeInputs(Phi, inputs, newHeader, newRootNode); 181 } else { 182 // All other instructions go through the generic input finder 183 // Loop over the operands of each instruction (inputs) 184 for (User::op_iterator op = I->op_begin(), opE = I->op_end(); 185 op != opE; ++op) 186 if (Instruction *opI = dyn_cast<Instruction>(*op)) { 187 // Check if definition of this operand is within the loop 188 if (!BlocksToExtract.count(opI->getParent())) 189 inputs.push_back(opI); 190 } else if (isa<Argument>(*op)) { 191 inputs.push_back(*op); 192 } 193 } 194 195 // Consider uses of this instruction (outputs) 196 for (Value::use_iterator UI = I->use_begin(), E = I->use_end(); 197 UI != E; ++UI) 198 if (!BlocksToExtract.count(cast<Instruction>(*UI)->getParent())) { 199 outputs.push_back(I); 200 break; 201 } 202 } // for: insts 203 } // for: basic blocks 204 } 205 206 void CodeExtractor::rewritePhiNodes(Function *F, 207 BasicBlock *newFuncRoot) { 208 // Write any changes that were saved before: use function arguments as inputs 209 for (PhiVal2ArgTy::iterator i = PhiVal2Arg.begin(), e = PhiVal2Arg.end(); 210 i != e; ++i) { 211 PHINode *phi = i->first; 212 PhiValChangesTy &values = i->second; 213 for (unsigned cIdx = 0, ce = values.size(); cIdx != ce; ++cIdx) 214 { 215 unsigned phiValueIdx = values[cIdx].first, argNum = values[cIdx].second; 216 if (phiValueIdx < phi->getNumIncomingValues()) 217 phi->setIncomingValue(phiValueIdx, getFunctionArg(F, argNum)); 218 else 219 phi->addIncoming(getFunctionArg(F, argNum), newFuncRoot); 220 } 221 } 222 223 // Delete any invalid Phi node inputs that were marked as NULL previously 224 for (PhiVal2ArgTy::iterator i = PhiVal2Arg.begin(), e = PhiVal2Arg.end(); 225 i != e; ++i) { 226 PHINode *phi = i->first; 227 for (unsigned idx = 0, end = phi->getNumIncomingValues(); idx != end; ++idx) 228 { 229 if (phi->getIncomingValue(idx) == 0 && phi->getIncomingBlock(idx) == 0) { 230 phi->removeIncomingValue(idx); 231 --idx; 232 --end; 233 } 234 } 235 } 236 237 // We are done with the saved values 238 PhiVal2Arg.clear(); 239 } 240 241 242 /// constructFunction - make a function based on inputs and outputs, as follows: 243 /// f(in0, ..., inN, out0, ..., outN) 244 /// 245 Function *CodeExtractor::constructFunction(const Values &inputs, 246 const Values &outputs, 247 BasicBlock *newRootNode, 248 BasicBlock *newHeader, 249 Function *oldFunction, Module *M) { 250 DEBUG(std::cerr << "inputs: " << inputs.size() << "\n"); 251 DEBUG(std::cerr << "outputs: " << outputs.size() << "\n"); 252 BasicBlock *header = *BlocksToExtract.begin(); 253 254 // This function returns unsigned, outputs will go back by reference. 255 Type *retTy = Type::UShortTy; 256 std::vector<const Type*> paramTy; 257 258 // Add the types of the input values to the function's argument list 259 for (Values::const_iterator i = inputs.begin(), 260 e = inputs.end(); i != e; ++i) { 261 const Value *value = *i; 262 DEBUG(std::cerr << "value used in func: " << value << "\n"); 263 paramTy.push_back(value->getType()); 264 } 265 266 // Add the types of the output values to the function's argument list. 267 for (Values::const_iterator I = outputs.begin(), E = outputs.end(); 268 I != E; ++I) { 269 DEBUG(std::cerr << "instr used in func: " << *I << "\n"); 270 paramTy.push_back(PointerType::get((*I)->getType())); 271 } 272 273 DEBUG(std::cerr << "Function type: " << retTy << " f("); 274 for (std::vector<const Type*>::iterator i = paramTy.begin(), 275 e = paramTy.end(); i != e; ++i) 276 DEBUG(std::cerr << *i << ", "); 277 DEBUG(std::cerr << ")\n"); 278 279 const FunctionType *funcType = FunctionType::get(retTy, paramTy, false); 280 281 // Create the new function 282 Function *newFunction = new Function(funcType, 283 GlobalValue::InternalLinkage, 284 oldFunction->getName() + "_code", M); 285 newFunction->getBasicBlockList().push_back(newRootNode); 286 287 // Create an iterator to name all of the arguments we inserted. 288 Function::aiterator AI = newFunction->abegin(); 289 290 // Rewrite all users of the inputs in the extracted region to use the 291 // arguments instead. 292 for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI) { 293 AI->setName(inputs[i]->getName()); 294 std::vector<User*> Users(inputs[i]->use_begin(), inputs[i]->use_end()); 295 for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end(); 296 use != useE; ++use) 297 if (Instruction* inst = dyn_cast<Instruction>(*use)) 298 if (BlocksToExtract.count(inst->getParent())) 299 inst->replaceUsesOfWith(inputs[i], AI); 300 } 301 302 // Set names for all of the output arguments. 303 for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI) 304 AI->setName(outputs[i]->getName()+".out"); 305 306 307 // Rewrite branches to basic blocks outside of the loop to new dummy blocks 308 // within the new function. This must be done before we lose track of which 309 // blocks were originally in the code region. 310 std::vector<User*> Users(header->use_begin(), header->use_end()); 311 for (std::vector<User*>::iterator i = Users.begin(), e = Users.end(); 312 i != e; ++i) { 313 if (BranchInst *inst = dyn_cast<BranchInst>(*i)) { 314 BasicBlock *BB = inst->getParent(); 315 if (!BlocksToExtract.count(BB) && BB->getParent() == oldFunction) { 316 // The BasicBlock which contains the branch is not in the region 317 // modify the branch target to a new block 318 inst->replaceUsesOfWith(header, newHeader); 319 } 320 } 321 } 322 323 return newFunction; 324 } 325 326 void CodeExtractor::moveCodeToFunction(Function *newFunction) { 327 Function *oldFunc = (*BlocksToExtract.begin())->getParent(); 328 Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList(); 329 Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList(); 330 331 for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), 332 e = BlocksToExtract.end(); i != e; ++i) { 333 // Delete the basic block from the old function, and the list of blocks 334 oldBlocks.remove(*i); 335 336 // Insert this basic block into the new function 337 newBlocks.push_back(*i); 338 } 339 } 340 341 void 342 CodeExtractor::emitCallAndSwitchStatement(Function *newFunction, 343 BasicBlock *codeReplacer, 344 Values &inputs, 345 Values &outputs) { 346 // Emit a call to the new function, passing allocated memory for outputs and 347 // just plain inputs for non-scalars 348 std::vector<Value*> params(inputs); 349 350 // Get an iterator to the first output argument. 351 Function::aiterator OutputArgBegin = newFunction->abegin(); 352 std::advance(OutputArgBegin, inputs.size()); 353 354 for (unsigned i = 0, e = outputs.size(); i != e; ++i) { 355 Value *Output = outputs[i]; 356 // Create allocas for scalar outputs 357 AllocaInst *alloca = 358 new AllocaInst(outputs[i]->getType(), 0, Output->getName()+".loc", 359 codeReplacer->getParent()->begin()->begin()); 360 params.push_back(alloca); 361 362 LoadInst *load = new LoadInst(alloca, Output->getName()+".reload"); 363 codeReplacer->getInstList().push_back(load); 364 std::vector<User*> Users(outputs[i]->use_begin(), outputs[i]->use_end()); 365 for (unsigned u = 0, e = Users.size(); u != e; ++u) { 366 Instruction *inst = cast<Instruction>(Users[u]); 367 if (!BlocksToExtract.count(inst->getParent())) 368 inst->replaceUsesOfWith(outputs[i], load); 369 } 370 } 371 372 CallInst *call = new CallInst(newFunction, params, "targetBlock"); 373 codeReplacer->getInstList().push_front(call); 374 375 // Now we can emit a switch statement using the call as a value. 376 SwitchInst *TheSwitch = new SwitchInst(call, codeReplacer, codeReplacer); 377 378 // Since there may be multiple exits from the original region, make the new 379 // function return an unsigned, switch on that number. This loop iterates 380 // over all of the blocks in the extracted region, updating any terminator 381 // instructions in the to-be-extracted region that branch to blocks that are 382 // not in the region to be extracted. 383 std::map<BasicBlock*, BasicBlock*> ExitBlockMap; 384 385 unsigned switchVal = 0; 386 for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(), 387 e = BlocksToExtract.end(); i != e; ++i) { 388 TerminatorInst *TI = (*i)->getTerminator(); 389 for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) 390 if (!BlocksToExtract.count(TI->getSuccessor(i))) { 391 BasicBlock *OldTarget = TI->getSuccessor(i); 392 // add a new basic block which returns the appropriate value 393 BasicBlock *&NewTarget = ExitBlockMap[OldTarget]; 394 if (!NewTarget) { 395 // If we don't already have an exit stub for this non-extracted 396 // destination, create one now! 397 NewTarget = new BasicBlock(OldTarget->getName() + ".exitStub", 398 newFunction); 399 400 ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal++); 401 ReturnInst *NTRet = new ReturnInst(brVal, NewTarget); 402 403 // Update the switch instruction. 404 TheSwitch->addCase(brVal, OldTarget); 405 406 // Restore values just before we exit 407 // FIXME: Use a GetElementPtr to bunch the outputs in a struct 408 Function::aiterator OAI = OutputArgBegin; 409 for (unsigned out = 0, e = outputs.size(); out != e; ++out, ++OAI) 410 if (!DS || 411 DS->dominates(cast<Instruction>(outputs[out])->getParent(), 412 TI->getParent())) 413 new StoreInst(outputs[out], OAI, NTRet); 414 } 415 416 // rewrite the original branch instruction with this new target 417 TI->setSuccessor(i, NewTarget); 418 } 419 } 420 421 // Now that we've done the deed, make the default destination of the switch 422 // instruction be one of the exit blocks of the region. 423 if (TheSwitch->getNumSuccessors() > 1) { 424 // FIXME: this is broken w.r.t. PHI nodes, but the old code was more broken. 425 // This edge is not traversable. 426 TheSwitch->setSuccessor(0, TheSwitch->getSuccessor(1)); 427 } 428 } 429 430 431 /// ExtractRegion - Removes a loop from a function, replaces it with a call to 432 /// new function. Returns pointer to the new function. 433 /// 434 /// algorithm: 435 /// 436 /// find inputs and outputs for the region 437 /// 438 /// for inputs: add to function as args, map input instr* to arg# 439 /// for outputs: add allocas for scalars, 440 /// add to func as args, map output instr* to arg# 441 /// 442 /// rewrite func to use argument #s instead of instr* 443 /// 444 /// for each scalar output in the function: at every exit, store intermediate 445 /// computed result back into memory. 446 /// 447 Function *CodeExtractor::ExtractCodeRegion(const std::vector<BasicBlock*> &code) 448 { 449 // 1) Find inputs, outputs 450 // 2) Construct new function 451 // * Add allocas for defs, pass as args by reference 452 // * Pass in uses as args 453 // 3) Move code region, add call instr to func 454 // 455 BlocksToExtract.insert(code.begin(), code.end()); 456 457 Values inputs, outputs; 458 459 // Assumption: this is a single-entry code region, and the header is the first 460 // block in the region. 461 BasicBlock *header = code[0]; 462 for (unsigned i = 1, e = code.size(); i != e; ++i) 463 for (pred_iterator PI = pred_begin(code[i]), E = pred_end(code[i]); 464 PI != E; ++PI) 465 assert(BlocksToExtract.count(*PI) && 466 "No blocks in this region may have entries from outside the region" 467 " except for the first block!"); 468 469 Function *oldFunction = header->getParent(); 470 471 // This takes place of the original loop 472 BasicBlock *codeReplacer = new BasicBlock("codeRepl", oldFunction); 473 474 // The new function needs a root node because other nodes can branch to the 475 // head of the loop, and the root cannot have predecessors 476 BasicBlock *newFuncRoot = new BasicBlock("newFuncRoot"); 477 newFuncRoot->getInstList().push_back(new BranchInst(header)); 478 479 // Find inputs to, outputs from the code region 480 // 481 // If one of the inputs is coming from a different basic block and it's in a 482 // phi node, we need to rewrite the phi node: 483 // 484 // * All the inputs which involve basic blocks OUTSIDE of this region go into 485 // a NEW phi node that takes care of finding which value really came in. 486 // The result of this phi is passed to the function as an argument. 487 // 488 // * All the other phi values stay. 489 // 490 // FIXME: PHI nodes' incoming blocks aren't being rewritten to accomodate for 491 // blocks moving to a new function. 492 // SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass 493 // the values as parameters to the function 494 findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot); 495 496 // Step 2: Construct new function based on inputs/outputs, 497 // Add allocas for all defs 498 Function *newFunction = constructFunction(inputs, outputs, newFuncRoot, 499 codeReplacer, oldFunction, 500 oldFunction->getParent()); 501 502 rewritePhiNodes(newFunction, newFuncRoot); 503 504 emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs); 505 506 moveCodeToFunction(newFunction); 507 508 DEBUG(if (verifyFunction(*newFunction)) abort()); 509 return newFunction; 510 } 511 512 /// ExtractCodeRegion - slurp a sequence of basic blocks into a brand new 513 /// function 514 /// 515 Function* llvm::ExtractCodeRegion(DominatorSet &DS, 516 const std::vector<BasicBlock*> &code) { 517 return CodeExtractor(&DS).ExtractCodeRegion(code); 518 } 519 520 /// ExtractBasicBlock - slurp a natural loop into a brand new function 521 /// 522 Function* llvm::ExtractLoop(DominatorSet &DS, Loop *L) { 523 return CodeExtractor(&DS).ExtractCodeRegion(L->getBlocks()); 524 } 525 526 /// ExtractBasicBlock - slurp a basic block into a brand new function 527 /// 528 Function* llvm::ExtractBasicBlock(BasicBlock *BB) { 529 std::vector<BasicBlock*> Blocks; 530 Blocks.push_back(BB); 531 return CodeExtractor().ExtractCodeRegion(Blocks); 532 } 533