1 //===- llvm/unittest/IR/LegacyPassManager.cpp - Legacy PassManager tests --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This unit test exercises the legacy pass manager infrastructure. We use the 10 // old names as well to ensure that the source-level compatibility is preserved 11 // where possible. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/IR/LegacyPassManager.h" 16 #include "llvm/Analysis/CallGraph.h" 17 #include "llvm/Analysis/CallGraphSCCPass.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/LoopPass.h" 20 #include "llvm/AsmParser/Parser.h" 21 #include "llvm/IR/AbstractCallSite.h" 22 #include "llvm/IR/BasicBlock.h" 23 #include "llvm/IR/CallingConv.h" 24 #include "llvm/IR/DataLayout.h" 25 #include "llvm/IR/DerivedTypes.h" 26 #include "llvm/IR/Function.h" 27 #include "llvm/IR/GlobalVariable.h" 28 #include "llvm/IR/Instructions.h" 29 #include "llvm/IR/LLVMContext.h" 30 #include "llvm/IR/Module.h" 31 #include "llvm/IR/OptBisect.h" 32 #include "llvm/InitializePasses.h" 33 #include "llvm/Support/MathExtras.h" 34 #include "llvm/Support/SourceMgr.h" 35 #include "llvm/Support/raw_ostream.h" 36 #include "llvm/Transforms/Utils/CallGraphUpdater.h" 37 #include "gtest/gtest.h" 38 39 using namespace llvm; 40 41 namespace llvm { 42 void initializeModuleNDMPass(PassRegistry&); 43 void initializeFPassPass(PassRegistry&); 44 void initializeCGPassPass(PassRegistry&); 45 void initializeLPassPass(PassRegistry&); 46 47 namespace { 48 // ND = no deps 49 // NM = no modifications 50 struct ModuleNDNM: public ModulePass { 51 public: 52 static char run; 53 static char ID; 54 ModuleNDNM() : ModulePass(ID) { } 55 bool runOnModule(Module &M) override { 56 run++; 57 return false; 58 } 59 void getAnalysisUsage(AnalysisUsage &AU) const override { 60 AU.setPreservesAll(); 61 } 62 }; 63 char ModuleNDNM::ID=0; 64 char ModuleNDNM::run=0; 65 66 struct ModuleNDM : public ModulePass { 67 public: 68 static char run; 69 static char ID; 70 ModuleNDM() : ModulePass(ID) {} 71 bool runOnModule(Module &M) override { 72 run++; 73 return true; 74 } 75 }; 76 char ModuleNDM::ID=0; 77 char ModuleNDM::run=0; 78 79 struct ModuleNDM2 : public ModulePass { 80 public: 81 static char run; 82 static char ID; 83 ModuleNDM2() : ModulePass(ID) {} 84 bool runOnModule(Module &M) override { 85 run++; 86 return true; 87 } 88 }; 89 char ModuleNDM2::ID=0; 90 char ModuleNDM2::run=0; 91 92 struct ModuleDNM : public ModulePass { 93 public: 94 static char run; 95 static char ID; 96 ModuleDNM() : ModulePass(ID) { 97 initializeModuleNDMPass(*PassRegistry::getPassRegistry()); 98 } 99 bool runOnModule(Module &M) override { 100 run++; 101 return false; 102 } 103 void getAnalysisUsage(AnalysisUsage &AU) const override { 104 AU.addRequired<ModuleNDM>(); 105 AU.setPreservesAll(); 106 } 107 }; 108 char ModuleDNM::ID=0; 109 char ModuleDNM::run=0; 110 111 template<typename P> 112 struct PassTestBase : public P { 113 protected: 114 static int runc; 115 static bool initialized; 116 static bool finalized; 117 int allocated; 118 void run() { 119 EXPECT_TRUE(initialized); 120 EXPECT_FALSE(finalized); 121 EXPECT_EQ(0, allocated); 122 allocated++; 123 runc++; 124 } 125 public: 126 static char ID; 127 static void finishedOK(int run) { 128 EXPECT_GT(runc, 0); 129 EXPECT_TRUE(initialized); 130 EXPECT_TRUE(finalized); 131 EXPECT_EQ(run, runc); 132 } 133 PassTestBase() : P(ID), allocated(0) { 134 initialized = false; 135 finalized = false; 136 runc = 0; 137 } 138 139 void releaseMemory() override { 140 EXPECT_GT(runc, 0); 141 EXPECT_GT(allocated, 0); 142 allocated--; 143 } 144 }; 145 template<typename P> char PassTestBase<P>::ID; 146 template<typename P> int PassTestBase<P>::runc; 147 template<typename P> bool PassTestBase<P>::initialized; 148 template<typename P> bool PassTestBase<P>::finalized; 149 150 template<typename T, typename P> 151 struct PassTest : public PassTestBase<P> { 152 public: 153 #ifndef _MSC_VER // MSVC complains that Pass is not base class. 154 using llvm::Pass::doInitialization; 155 using llvm::Pass::doFinalization; 156 #endif 157 bool doInitialization(T &t) override { 158 EXPECT_FALSE(PassTestBase<P>::initialized); 159 PassTestBase<P>::initialized = true; 160 return false; 161 } 162 bool doFinalization(T &t) override { 163 EXPECT_FALSE(PassTestBase<P>::finalized); 164 PassTestBase<P>::finalized = true; 165 EXPECT_EQ(0, PassTestBase<P>::allocated); 166 return false; 167 } 168 }; 169 170 struct CGPass : public PassTest<CallGraph, CallGraphSCCPass> { 171 public: 172 CGPass() { 173 initializeCGPassPass(*PassRegistry::getPassRegistry()); 174 } 175 bool runOnSCC(CallGraphSCC &SCMM) override { 176 run(); 177 return false; 178 } 179 }; 180 181 struct FPass : public PassTest<Module, FunctionPass> { 182 public: 183 bool runOnFunction(Function &F) override { 184 // FIXME: PR4112 185 // EXPECT_TRUE(getAnalysisIfAvailable<DataLayout>()); 186 run(); 187 return false; 188 } 189 }; 190 191 struct LPass : public PassTestBase<LoopPass> { 192 private: 193 static int initcount; 194 static int fincount; 195 public: 196 LPass() { 197 initializeLPassPass(*PassRegistry::getPassRegistry()); 198 initcount = 0; fincount=0; 199 EXPECT_FALSE(initialized); 200 } 201 static void finishedOK(int run, int finalized) { 202 PassTestBase<LoopPass>::finishedOK(run); 203 EXPECT_EQ(run, initcount); 204 EXPECT_EQ(finalized, fincount); 205 } 206 using llvm::Pass::doInitialization; 207 using llvm::Pass::doFinalization; 208 bool doInitialization(Loop* L, LPPassManager &LPM) override { 209 initialized = true; 210 initcount++; 211 return false; 212 } 213 bool runOnLoop(Loop *L, LPPassManager &LPM) override { 214 run(); 215 return false; 216 } 217 bool doFinalization() override { 218 fincount++; 219 finalized = true; 220 return false; 221 } 222 }; 223 int LPass::initcount=0; 224 int LPass::fincount=0; 225 226 struct OnTheFlyTest: public ModulePass { 227 public: 228 static char ID; 229 OnTheFlyTest() : ModulePass(ID) { 230 initializeFPassPass(*PassRegistry::getPassRegistry()); 231 } 232 bool runOnModule(Module &M) override { 233 for (Module::iterator I=M.begin(),E=M.end(); I != E; ++I) { 234 Function &F = *I; 235 { 236 SCOPED_TRACE("Running on the fly function pass"); 237 getAnalysis<FPass>(F); 238 } 239 } 240 return false; 241 } 242 void getAnalysisUsage(AnalysisUsage &AU) const override { 243 AU.addRequired<FPass>(); 244 } 245 }; 246 char OnTheFlyTest::ID=0; 247 248 TEST(PassManager, RunOnce) { 249 LLVMContext Context; 250 Module M("test-once", Context); 251 struct ModuleNDNM *mNDNM = new ModuleNDNM(); 252 struct ModuleDNM *mDNM = new ModuleDNM(); 253 struct ModuleNDM *mNDM = new ModuleNDM(); 254 struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); 255 256 mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; 257 258 legacy::PassManager Passes; 259 Passes.add(mNDM2); 260 Passes.add(mNDM); 261 Passes.add(mNDNM); 262 Passes.add(mDNM); 263 264 Passes.run(M); 265 // each pass must be run exactly once, since nothing invalidates them 266 EXPECT_EQ(1, mNDM->run); 267 EXPECT_EQ(1, mNDNM->run); 268 EXPECT_EQ(1, mDNM->run); 269 EXPECT_EQ(1, mNDM2->run); 270 } 271 272 TEST(PassManager, ReRun) { 273 LLVMContext Context; 274 Module M("test-rerun", Context); 275 struct ModuleNDNM *mNDNM = new ModuleNDNM(); 276 struct ModuleDNM *mDNM = new ModuleDNM(); 277 struct ModuleNDM *mNDM = new ModuleNDM(); 278 struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); 279 280 mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; 281 282 legacy::PassManager Passes; 283 Passes.add(mNDM); 284 Passes.add(mNDNM); 285 Passes.add(mNDM2);// invalidates mNDM needed by mDNM 286 Passes.add(mDNM); 287 288 Passes.run(M); 289 // Some passes must be rerun because a pass that modified the 290 // module/function was run in between 291 EXPECT_EQ(2, mNDM->run); 292 EXPECT_EQ(1, mNDNM->run); 293 EXPECT_EQ(1, mNDM2->run); 294 EXPECT_EQ(1, mDNM->run); 295 } 296 297 Module *makeLLVMModule(LLVMContext &Context); 298 299 template<typename T> 300 void MemoryTestHelper(int run) { 301 LLVMContext Context; 302 std::unique_ptr<Module> M(makeLLVMModule(Context)); 303 T *P = new T(); 304 legacy::PassManager Passes; 305 Passes.add(P); 306 Passes.run(*M); 307 T::finishedOK(run); 308 } 309 310 template<typename T> 311 void MemoryTestHelper(int run, int N) { 312 LLVMContext Context; 313 Module *M = makeLLVMModule(Context); 314 T *P = new T(); 315 legacy::PassManager Passes; 316 Passes.add(P); 317 Passes.run(*M); 318 T::finishedOK(run, N); 319 delete M; 320 } 321 322 TEST(PassManager, Memory) { 323 // SCC#1: test1->test2->test3->test1 324 // SCC#2: test4 325 // SCC#3: indirect call node 326 { 327 SCOPED_TRACE("Callgraph pass"); 328 MemoryTestHelper<CGPass>(3); 329 } 330 331 { 332 SCOPED_TRACE("Function pass"); 333 MemoryTestHelper<FPass>(4);// 4 functions 334 } 335 336 { 337 SCOPED_TRACE("Loop pass"); 338 MemoryTestHelper<LPass>(2, 1); //2 loops, 1 function 339 } 340 341 } 342 343 TEST(PassManager, MemoryOnTheFly) { 344 LLVMContext Context; 345 Module *M = makeLLVMModule(Context); 346 { 347 SCOPED_TRACE("Running OnTheFlyTest"); 348 struct OnTheFlyTest *O = new OnTheFlyTest(); 349 legacy::PassManager Passes; 350 Passes.add(O); 351 Passes.run(*M); 352 353 FPass::finishedOK(4); 354 } 355 delete M; 356 } 357 358 // Skips or runs optional passes. 359 struct CustomOptPassGate : public OptPassGate { 360 bool Skip; 361 CustomOptPassGate(bool Skip) : Skip(Skip) { } 362 bool shouldRunPass(const Pass *P, StringRef IRDescription) override { 363 if (P->getPassKind() == PT_Module) 364 return !Skip; 365 return OptPassGate::shouldRunPass(P, IRDescription); 366 } 367 bool isEnabled() const override { return true; } 368 }; 369 370 // Optional module pass. 371 struct ModuleOpt: public ModulePass { 372 char run = 0; 373 static char ID; 374 ModuleOpt() : ModulePass(ID) { } 375 bool runOnModule(Module &M) override { 376 if (!skipModule(M)) 377 run++; 378 return false; 379 } 380 }; 381 char ModuleOpt::ID=0; 382 383 TEST(PassManager, CustomOptPassGate) { 384 LLVMContext Context0; 385 LLVMContext Context1; 386 LLVMContext Context2; 387 CustomOptPassGate SkipOptionalPasses(true); 388 CustomOptPassGate RunOptionalPasses(false); 389 390 Module M0("custom-opt-bisect", Context0); 391 Module M1("custom-opt-bisect", Context1); 392 Module M2("custom-opt-bisect2", Context2); 393 struct ModuleOpt *mOpt0 = new ModuleOpt(); 394 struct ModuleOpt *mOpt1 = new ModuleOpt(); 395 struct ModuleOpt *mOpt2 = new ModuleOpt(); 396 397 mOpt0->run = mOpt1->run = mOpt2->run = 0; 398 399 legacy::PassManager Passes0; 400 legacy::PassManager Passes1; 401 legacy::PassManager Passes2; 402 403 Passes0.add(mOpt0); 404 Passes1.add(mOpt1); 405 Passes2.add(mOpt2); 406 407 Context1.setOptPassGate(SkipOptionalPasses); 408 Context2.setOptPassGate(RunOptionalPasses); 409 410 Passes0.run(M0); 411 Passes1.run(M1); 412 Passes2.run(M2); 413 414 // By default optional passes are run. 415 EXPECT_EQ(1, mOpt0->run); 416 417 // The first context skips optional passes. 418 EXPECT_EQ(0, mOpt1->run); 419 420 // The second context runs optional passes. 421 EXPECT_EQ(1, mOpt2->run); 422 } 423 424 Module *makeLLVMModule(LLVMContext &Context) { 425 // Module Construction 426 Module *mod = new Module("test-mem", Context); 427 mod->setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-" 428 "i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-" 429 "a:0:64-s:64:64-f80:128:128"); 430 mod->setTargetTriple("x86_64-unknown-linux-gnu"); 431 432 // Type Definitions 433 std::vector<Type*>FuncTy_0_args; 434 FunctionType *FuncTy_0 = FunctionType::get( 435 /*Result=*/IntegerType::get(Context, 32), 436 /*Params=*/FuncTy_0_args, 437 /*isVarArg=*/false); 438 439 std::vector<Type*>FuncTy_2_args; 440 FuncTy_2_args.push_back(IntegerType::get(Context, 1)); 441 FunctionType *FuncTy_2 = FunctionType::get( 442 /*Result=*/Type::getVoidTy(Context), 443 /*Params=*/FuncTy_2_args, 444 /*isVarArg=*/false); 445 446 // Function Declarations 447 448 Function* func_test1 = Function::Create( 449 /*Type=*/FuncTy_0, 450 /*Linkage=*/GlobalValue::ExternalLinkage, 451 /*Name=*/"test1", mod); 452 func_test1->setCallingConv(CallingConv::C); 453 AttributeList func_test1_PAL; 454 func_test1->setAttributes(func_test1_PAL); 455 456 Function* func_test2 = Function::Create( 457 /*Type=*/FuncTy_0, 458 /*Linkage=*/GlobalValue::ExternalLinkage, 459 /*Name=*/"test2", mod); 460 func_test2->setCallingConv(CallingConv::C); 461 AttributeList func_test2_PAL; 462 func_test2->setAttributes(func_test2_PAL); 463 464 Function* func_test3 = Function::Create( 465 /*Type=*/FuncTy_0, 466 /*Linkage=*/GlobalValue::InternalLinkage, 467 /*Name=*/"test3", mod); 468 func_test3->setCallingConv(CallingConv::C); 469 AttributeList func_test3_PAL; 470 func_test3->setAttributes(func_test3_PAL); 471 472 Function* func_test4 = Function::Create( 473 /*Type=*/FuncTy_2, 474 /*Linkage=*/GlobalValue::ExternalLinkage, 475 /*Name=*/"test4", mod); 476 func_test4->setCallingConv(CallingConv::C); 477 AttributeList func_test4_PAL; 478 func_test4->setAttributes(func_test4_PAL); 479 480 // Global Variable Declarations 481 482 483 // Constant Definitions 484 485 // Global Variable Definitions 486 487 // Function Definitions 488 489 // Function: test1 (func_test1) 490 { 491 492 BasicBlock *label_entry = 493 BasicBlock::Create(Context, "entry", func_test1, nullptr); 494 495 // Block entry (label_entry) 496 CallInst* int32_3 = CallInst::Create(func_test2, "", label_entry); 497 int32_3->setCallingConv(CallingConv::C); 498 int32_3->setTailCall(false); 499 AttributeList int32_3_PAL; 500 int32_3->setAttributes(int32_3_PAL); 501 502 ReturnInst::Create(Context, int32_3, label_entry); 503 } 504 505 // Function: test2 (func_test2) 506 { 507 508 BasicBlock *label_entry_5 = 509 BasicBlock::Create(Context, "entry", func_test2, nullptr); 510 511 // Block entry (label_entry_5) 512 CallInst* int32_6 = CallInst::Create(func_test3, "", label_entry_5); 513 int32_6->setCallingConv(CallingConv::C); 514 int32_6->setTailCall(false); 515 AttributeList int32_6_PAL; 516 int32_6->setAttributes(int32_6_PAL); 517 518 ReturnInst::Create(Context, int32_6, label_entry_5); 519 } 520 521 // Function: test3 (func_test3) 522 { 523 524 BasicBlock *label_entry_8 = 525 BasicBlock::Create(Context, "entry", func_test3, nullptr); 526 527 // Block entry (label_entry_8) 528 CallInst* int32_9 = CallInst::Create(func_test1, "", label_entry_8); 529 int32_9->setCallingConv(CallingConv::C); 530 int32_9->setTailCall(false); 531 AttributeList int32_9_PAL; 532 int32_9->setAttributes(int32_9_PAL); 533 534 ReturnInst::Create(Context, int32_9, label_entry_8); 535 } 536 537 // Function: test4 (func_test4) 538 { 539 Function::arg_iterator args = func_test4->arg_begin(); 540 Value *int1_f = &*args++; 541 int1_f->setName("f"); 542 543 BasicBlock *label_entry_11 = 544 BasicBlock::Create(Context, "entry", func_test4, nullptr); 545 BasicBlock *label_bb = 546 BasicBlock::Create(Context, "bb", func_test4, nullptr); 547 BasicBlock *label_bb1 = 548 BasicBlock::Create(Context, "bb1", func_test4, nullptr); 549 BasicBlock *label_return = 550 BasicBlock::Create(Context, "return", func_test4, nullptr); 551 552 // Block entry (label_entry_11) 553 auto *AI = new AllocaInst(func_test3->getType(), 0, "func3ptr", 554 label_entry_11); 555 new StoreInst(func_test3, AI, label_entry_11); 556 BranchInst::Create(label_bb, label_entry_11); 557 558 // Block bb (label_bb) 559 BranchInst::Create(label_bb, label_bb1, int1_f, label_bb); 560 561 // Block bb1 (label_bb1) 562 BranchInst::Create(label_bb1, label_return, int1_f, label_bb1); 563 564 // Block return (label_return) 565 ReturnInst::Create(Context, label_return); 566 } 567 return mod; 568 } 569 570 /// Split a simple function which contains only a call and a return into two 571 /// such that the first calls the second and the second whoever was called 572 /// initially. 573 Function *splitSimpleFunction(Function &F) { 574 LLVMContext &Context = F.getContext(); 575 Function *SF = Function::Create(F.getFunctionType(), F.getLinkage(), 576 F.getName() + "b", F.getParent()); 577 F.setName(F.getName() + "a"); 578 BasicBlock *Entry = BasicBlock::Create(Context, "entry", SF, nullptr); 579 CallInst &CI = cast<CallInst>(F.getEntryBlock().front()); 580 CI.clone()->insertBefore(ReturnInst::Create(Context, Entry)); 581 CI.setCalledFunction(SF); 582 return SF; 583 } 584 585 struct CGModifierPass : public CGPass { 586 unsigned NumSCCs = 0; 587 unsigned NumFns = 0; 588 unsigned NumFnDecls = 0; 589 unsigned SetupWorked = 0; 590 unsigned NumExtCalledBefore = 0; 591 unsigned NumExtCalledAfter = 0; 592 593 CallGraphUpdater CGU; 594 595 bool runOnSCC(CallGraphSCC &SCMM) override { 596 ++NumSCCs; 597 for (CallGraphNode *N : SCMM) { 598 if (N->getFunction()){ 599 ++NumFns; 600 NumFnDecls += N->getFunction()->isDeclaration(); 601 } 602 } 603 CGPass::run(); 604 605 CallGraph &CG = const_cast<CallGraph &>(SCMM.getCallGraph()); 606 CallGraphNode *ExtCallingNode = CG.getExternalCallingNode(); 607 NumExtCalledBefore = ExtCallingNode->size(); 608 609 if (SCMM.size() <= 1) 610 return false; 611 612 CallGraphNode *N = *(SCMM.begin()); 613 Function *F = N->getFunction(); 614 Module *M = F->getParent(); 615 Function *Test1F = M->getFunction("test1"); 616 Function *Test2aF = M->getFunction("test2a"); 617 Function *Test2bF = M->getFunction("test2b"); 618 Function *Test3F = M->getFunction("test3"); 619 620 auto InSCC = [&](Function *Fn) { 621 return llvm::any_of(SCMM, [Fn](CallGraphNode *CGN) { 622 return CGN->getFunction() == Fn; 623 }); 624 }; 625 626 if (!Test1F || !Test2aF || !Test2bF || !Test3F || !InSCC(Test1F) || 627 !InSCC(Test2aF) || !InSCC(Test2bF) || !InSCC(Test3F)) 628 return false; 629 630 CallInst *CI = dyn_cast<CallInst>(&Test1F->getEntryBlock().front()); 631 if (!CI || CI->getCalledFunction() != Test2aF) 632 return false; 633 634 SetupWorked += 1; 635 636 // Create a replica of test3 and just move the blocks there. 637 Function *Test3FRepl = Function::Create( 638 /*Type=*/Test3F->getFunctionType(), 639 /*Linkage=*/GlobalValue::InternalLinkage, 640 /*Name=*/"test3repl", Test3F->getParent()); 641 while (!Test3F->empty()) { 642 BasicBlock &BB = Test3F->front(); 643 BB.removeFromParent(); 644 BB.insertInto(Test3FRepl); 645 } 646 647 CGU.initialize(CG, SCMM); 648 649 // Replace test3 with the replica. This is legal as it is actually 650 // internal and the "capturing use" is not really capturing anything. 651 CGU.replaceFunctionWith(*Test3F, *Test3FRepl); 652 Test3F->replaceAllUsesWith(Test3FRepl); 653 654 // Rewrite the call in test1 to point to the replica of 3 not test2. 655 CI->setCalledFunction(Test3FRepl); 656 657 // Delete test2a and test2b and reanalyze 1 as we changed calls inside. 658 CGU.removeFunction(*Test2aF); 659 CGU.removeFunction(*Test2bF); 660 CGU.reanalyzeFunction(*Test1F); 661 662 return true; 663 } 664 665 bool doFinalization(CallGraph &CG) override { 666 CGU.finalize(); 667 // We removed test2 and replaced the internal test3. 668 NumExtCalledAfter = CG.getExternalCallingNode()->size(); 669 return true; 670 } 671 }; 672 673 TEST(PassManager, CallGraphUpdater0) { 674 // SCC#1: test1->test2a->test2b->test3->test1 675 // SCC#2: test4 676 // SCC#3: test3 (the empty function declaration as we replaced it with 677 // test3repl when we visited SCC#1) 678 // SCC#4: test2a->test2b (the empty function declarations as we deleted 679 // these functions when we visited SCC#1) 680 // SCC#5: indirect call node 681 682 LLVMContext Context; 683 std::unique_ptr<Module> M(makeLLVMModule(Context)); 684 ASSERT_EQ(M->getFunctionList().size(), 4U); 685 Function *F = M->getFunction("test2"); 686 Function *SF = splitSimpleFunction(*F); 687 CallInst::Create(F, "", &*SF->getEntryBlock().getFirstInsertionPt()); 688 ASSERT_EQ(M->getFunctionList().size(), 5U); 689 CGModifierPass *P = new CGModifierPass(); 690 legacy::PassManager Passes; 691 Passes.add(P); 692 Passes.run(*M); 693 ASSERT_EQ(P->SetupWorked, 1U); 694 ASSERT_EQ(P->NumSCCs, 4U); 695 ASSERT_EQ(P->NumFns, 6U); 696 ASSERT_EQ(P->NumFnDecls, 1U); 697 ASSERT_EQ(M->getFunctionList().size(), 3U); 698 ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U); 699 ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U); 700 } 701 702 // Test for call graph SCC pass that replaces all callback call instructions 703 // with clones and updates CallGraph by calling CallGraph::replaceCallEdge() 704 // method. Test is expected to complete successfully after running pass on 705 // all SCCs in the test module. 706 struct CallbackCallsModifierPass : public CGPass { 707 bool runOnSCC(CallGraphSCC &SCC) override { 708 CGPass::run(); 709 710 CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph()); 711 712 bool Changed = false; 713 for (CallGraphNode *CGN : SCC) { 714 Function *F = CGN->getFunction(); 715 if (!F || F->isDeclaration()) 716 continue; 717 718 SmallVector<CallBase *, 4u> Calls; 719 for (Use &U : F->uses()) { 720 AbstractCallSite ACS(&U); 721 if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U)) 722 continue; 723 Calls.push_back(cast<CallBase>(ACS.getInstruction())); 724 } 725 if (Calls.empty()) 726 continue; 727 728 for (CallBase *OldCB : Calls) { 729 CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()]; 730 assert(any_of(*CallerCGN, 731 [CGN](const CallGraphNode::CallRecord &CallRecord) { 732 return CallRecord.second == CGN; 733 }) && 734 "function is not a callee"); 735 736 CallBase *NewCB = cast<CallBase>(OldCB->clone()); 737 738 NewCB->insertBefore(OldCB); 739 NewCB->takeName(OldCB); 740 741 CallerCGN->replaceCallEdge(*OldCB, *NewCB, CG[F]); 742 743 OldCB->replaceAllUsesWith(NewCB); 744 OldCB->eraseFromParent(); 745 } 746 Changed = true; 747 } 748 return Changed; 749 } 750 }; 751 752 TEST(PassManager, CallbackCallsModifier0) { 753 LLVMContext Context; 754 755 const char *IR = "define void @foo() {\n" 756 " call void @broker(void (i8*)* @callback0, i8* null)\n" 757 " call void @broker(void (i8*)* @callback1, i8* null)\n" 758 " ret void\n" 759 "}\n" 760 "\n" 761 "declare !callback !0 void @broker(void (i8*)*, i8*)\n" 762 "\n" 763 "define internal void @callback0(i8* %arg) {\n" 764 " ret void\n" 765 "}\n" 766 "\n" 767 "define internal void @callback1(i8* %arg) {\n" 768 " ret void\n" 769 "}\n" 770 "\n" 771 "!0 = !{!1}\n" 772 "!1 = !{i64 0, i64 1, i1 false}"; 773 774 SMDiagnostic Err; 775 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context); 776 if (!M) 777 Err.print("LegacyPassManagerTest", errs()); 778 779 CallbackCallsModifierPass *P = new CallbackCallsModifierPass(); 780 legacy::PassManager Passes; 781 Passes.add(P); 782 Passes.run(*M); 783 } 784 } 785 } 786 787 INITIALIZE_PASS(ModuleNDM, "mndm", "mndm", false, false) 788 INITIALIZE_PASS_BEGIN(CGPass, "cgp","cgp", false, false) 789 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 790 INITIALIZE_PASS_END(CGPass, "cgp","cgp", false, false) 791 INITIALIZE_PASS(FPass, "fp","fp", false, false) 792 INITIALIZE_PASS_BEGIN(LPass, "lp","lp", false, false) 793 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 794 INITIALIZE_PASS_END(LPass, "lp","lp", false, false) 795