1 //===- llvm/unittest/IR/LegacyPassManager.cpp - Legacy PassManager tests --===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This unit test exercises the legacy pass manager infrastructure. We use the 10 // old names as well to ensure that the source-level compatibility is preserved 11 // where possible. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "llvm/IR/LegacyPassManager.h" 16 #include "llvm/Analysis/CallGraphSCCPass.h" 17 #include "llvm/Analysis/LoopInfo.h" 18 #include "llvm/Analysis/LoopPass.h" 19 #include "llvm/AsmParser/Parser.h" 20 #include "llvm/IR/AbstractCallSite.h" 21 #include "llvm/IR/BasicBlock.h" 22 #include "llvm/IR/CallingConv.h" 23 #include "llvm/IR/DataLayout.h" 24 #include "llvm/IR/DerivedTypes.h" 25 #include "llvm/IR/Function.h" 26 #include "llvm/IR/GlobalVariable.h" 27 #include "llvm/IR/Instructions.h" 28 #include "llvm/IR/LLVMContext.h" 29 #include "llvm/IR/Module.h" 30 #include "llvm/IR/OptBisect.h" 31 #include "llvm/InitializePasses.h" 32 #include "llvm/Support/MathExtras.h" 33 #include "llvm/Support/SourceMgr.h" 34 #include "llvm/Support/raw_ostream.h" 35 #include "llvm/Transforms/Utils/CallGraphUpdater.h" 36 #include "gtest/gtest.h" 37 38 using namespace llvm; 39 40 namespace llvm { 41 void initializeModuleNDMPass(PassRegistry&); 42 void initializeFPassPass(PassRegistry&); 43 void initializeCGPassPass(PassRegistry&); 44 void initializeLPassPass(PassRegistry&); 45 46 namespace { 47 // ND = no deps 48 // NM = no modifications 49 struct ModuleNDNM: public ModulePass { 50 public: 51 static char run; 52 static char ID; 53 ModuleNDNM() : ModulePass(ID) { } 54 bool runOnModule(Module &M) override { 55 run++; 56 return false; 57 } 58 void getAnalysisUsage(AnalysisUsage &AU) const override { 59 AU.setPreservesAll(); 60 } 61 }; 62 char ModuleNDNM::ID=0; 63 char ModuleNDNM::run=0; 64 65 struct ModuleNDM : public ModulePass { 66 public: 67 static char run; 68 static char ID; 69 ModuleNDM() : ModulePass(ID) {} 70 bool runOnModule(Module &M) override { 71 run++; 72 return true; 73 } 74 }; 75 char ModuleNDM::ID=0; 76 char ModuleNDM::run=0; 77 78 struct ModuleNDM2 : public ModulePass { 79 public: 80 static char run; 81 static char ID; 82 ModuleNDM2() : ModulePass(ID) {} 83 bool runOnModule(Module &M) override { 84 run++; 85 return true; 86 } 87 }; 88 char ModuleNDM2::ID=0; 89 char ModuleNDM2::run=0; 90 91 struct ModuleDNM : public ModulePass { 92 public: 93 static char run; 94 static char ID; 95 ModuleDNM() : ModulePass(ID) { 96 initializeModuleNDMPass(*PassRegistry::getPassRegistry()); 97 } 98 bool runOnModule(Module &M) override { 99 run++; 100 return false; 101 } 102 void getAnalysisUsage(AnalysisUsage &AU) const override { 103 AU.addRequired<ModuleNDM>(); 104 AU.setPreservesAll(); 105 } 106 }; 107 char ModuleDNM::ID=0; 108 char ModuleDNM::run=0; 109 110 template<typename P> 111 struct PassTestBase : public P { 112 protected: 113 static int runc; 114 static bool initialized; 115 static bool finalized; 116 int allocated; 117 void run() { 118 EXPECT_TRUE(initialized); 119 EXPECT_FALSE(finalized); 120 EXPECT_EQ(0, allocated); 121 allocated++; 122 runc++; 123 } 124 public: 125 static char ID; 126 static void finishedOK(int run) { 127 EXPECT_GT(runc, 0); 128 EXPECT_TRUE(initialized); 129 EXPECT_TRUE(finalized); 130 EXPECT_EQ(run, runc); 131 } 132 PassTestBase() : P(ID), allocated(0) { 133 initialized = false; 134 finalized = false; 135 runc = 0; 136 } 137 138 void releaseMemory() override { 139 EXPECT_GT(runc, 0); 140 EXPECT_GT(allocated, 0); 141 allocated--; 142 } 143 }; 144 template<typename P> char PassTestBase<P>::ID; 145 template<typename P> int PassTestBase<P>::runc; 146 template<typename P> bool PassTestBase<P>::initialized; 147 template<typename P> bool PassTestBase<P>::finalized; 148 149 template<typename T, typename P> 150 struct PassTest : public PassTestBase<P> { 151 public: 152 #ifndef _MSC_VER // MSVC complains that Pass is not base class. 153 using llvm::Pass::doInitialization; 154 using llvm::Pass::doFinalization; 155 #endif 156 bool doInitialization(T &t) override { 157 EXPECT_FALSE(PassTestBase<P>::initialized); 158 PassTestBase<P>::initialized = true; 159 return false; 160 } 161 bool doFinalization(T &t) override { 162 EXPECT_FALSE(PassTestBase<P>::finalized); 163 PassTestBase<P>::finalized = true; 164 EXPECT_EQ(0, PassTestBase<P>::allocated); 165 return false; 166 } 167 }; 168 169 struct CGPass : public PassTest<CallGraph, CallGraphSCCPass> { 170 public: 171 CGPass() { 172 initializeCGPassPass(*PassRegistry::getPassRegistry()); 173 } 174 bool runOnSCC(CallGraphSCC &SCMM) override { 175 run(); 176 return false; 177 } 178 }; 179 180 struct FPass : public PassTest<Module, FunctionPass> { 181 public: 182 bool runOnFunction(Function &F) override { 183 // FIXME: PR4112 184 // EXPECT_TRUE(getAnalysisIfAvailable<DataLayout>()); 185 run(); 186 return false; 187 } 188 }; 189 190 struct LPass : public PassTestBase<LoopPass> { 191 private: 192 static int initcount; 193 static int fincount; 194 public: 195 LPass() { 196 initializeLPassPass(*PassRegistry::getPassRegistry()); 197 initcount = 0; fincount=0; 198 EXPECT_FALSE(initialized); 199 } 200 static void finishedOK(int run, int finalized) { 201 PassTestBase<LoopPass>::finishedOK(run); 202 EXPECT_EQ(run, initcount); 203 EXPECT_EQ(finalized, fincount); 204 } 205 using llvm::Pass::doInitialization; 206 using llvm::Pass::doFinalization; 207 bool doInitialization(Loop* L, LPPassManager &LPM) override { 208 initialized = true; 209 initcount++; 210 return false; 211 } 212 bool runOnLoop(Loop *L, LPPassManager &LPM) override { 213 run(); 214 return false; 215 } 216 bool doFinalization() override { 217 fincount++; 218 finalized = true; 219 return false; 220 } 221 }; 222 int LPass::initcount=0; 223 int LPass::fincount=0; 224 225 struct OnTheFlyTest: public ModulePass { 226 public: 227 static char ID; 228 OnTheFlyTest() : ModulePass(ID) { 229 initializeFPassPass(*PassRegistry::getPassRegistry()); 230 } 231 bool runOnModule(Module &M) override { 232 for (Module::iterator I=M.begin(),E=M.end(); I != E; ++I) { 233 Function &F = *I; 234 { 235 SCOPED_TRACE("Running on the fly function pass"); 236 getAnalysis<FPass>(F); 237 } 238 } 239 return false; 240 } 241 void getAnalysisUsage(AnalysisUsage &AU) const override { 242 AU.addRequired<FPass>(); 243 } 244 }; 245 char OnTheFlyTest::ID=0; 246 247 TEST(PassManager, RunOnce) { 248 LLVMContext Context; 249 Module M("test-once", Context); 250 struct ModuleNDNM *mNDNM = new ModuleNDNM(); 251 struct ModuleDNM *mDNM = new ModuleDNM(); 252 struct ModuleNDM *mNDM = new ModuleNDM(); 253 struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); 254 255 mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; 256 257 legacy::PassManager Passes; 258 Passes.add(mNDM2); 259 Passes.add(mNDM); 260 Passes.add(mNDNM); 261 Passes.add(mDNM); 262 263 Passes.run(M); 264 // each pass must be run exactly once, since nothing invalidates them 265 EXPECT_EQ(1, mNDM->run); 266 EXPECT_EQ(1, mNDNM->run); 267 EXPECT_EQ(1, mDNM->run); 268 EXPECT_EQ(1, mNDM2->run); 269 } 270 271 TEST(PassManager, ReRun) { 272 LLVMContext Context; 273 Module M("test-rerun", Context); 274 struct ModuleNDNM *mNDNM = new ModuleNDNM(); 275 struct ModuleDNM *mDNM = new ModuleDNM(); 276 struct ModuleNDM *mNDM = new ModuleNDM(); 277 struct ModuleNDM2 *mNDM2 = new ModuleNDM2(); 278 279 mNDM->run = mNDNM->run = mDNM->run = mNDM2->run = 0; 280 281 legacy::PassManager Passes; 282 Passes.add(mNDM); 283 Passes.add(mNDNM); 284 Passes.add(mNDM2);// invalidates mNDM needed by mDNM 285 Passes.add(mDNM); 286 287 Passes.run(M); 288 // Some passes must be rerun because a pass that modified the 289 // module/function was run in between 290 EXPECT_EQ(2, mNDM->run); 291 EXPECT_EQ(1, mNDNM->run); 292 EXPECT_EQ(1, mNDM2->run); 293 EXPECT_EQ(1, mDNM->run); 294 } 295 296 Module *makeLLVMModule(LLVMContext &Context); 297 298 template<typename T> 299 void MemoryTestHelper(int run) { 300 LLVMContext Context; 301 std::unique_ptr<Module> M(makeLLVMModule(Context)); 302 T *P = new T(); 303 legacy::PassManager Passes; 304 Passes.add(P); 305 Passes.run(*M); 306 T::finishedOK(run); 307 } 308 309 template<typename T> 310 void MemoryTestHelper(int run, int N) { 311 LLVMContext Context; 312 Module *M = makeLLVMModule(Context); 313 T *P = new T(); 314 legacy::PassManager Passes; 315 Passes.add(P); 316 Passes.run(*M); 317 T::finishedOK(run, N); 318 delete M; 319 } 320 321 TEST(PassManager, Memory) { 322 // SCC#1: test1->test2->test3->test1 323 // SCC#2: test4 324 // SCC#3: indirect call node 325 { 326 SCOPED_TRACE("Callgraph pass"); 327 MemoryTestHelper<CGPass>(3); 328 } 329 330 { 331 SCOPED_TRACE("Function pass"); 332 MemoryTestHelper<FPass>(4);// 4 functions 333 } 334 335 { 336 SCOPED_TRACE("Loop pass"); 337 MemoryTestHelper<LPass>(2, 1); //2 loops, 1 function 338 } 339 340 } 341 342 TEST(PassManager, MemoryOnTheFly) { 343 LLVMContext Context; 344 Module *M = makeLLVMModule(Context); 345 { 346 SCOPED_TRACE("Running OnTheFlyTest"); 347 struct OnTheFlyTest *O = new OnTheFlyTest(); 348 legacy::PassManager Passes; 349 Passes.add(O); 350 Passes.run(*M); 351 352 FPass::finishedOK(4); 353 } 354 delete M; 355 } 356 357 // Skips or runs optional passes. 358 struct CustomOptPassGate : public OptPassGate { 359 bool Skip; 360 CustomOptPassGate(bool Skip) : Skip(Skip) { } 361 bool shouldRunPass(const Pass *P, StringRef IRDescription) override { 362 if (P->getPassKind() == PT_Module) 363 return !Skip; 364 return OptPassGate::shouldRunPass(P, IRDescription); 365 } 366 bool isEnabled() const override { return true; } 367 }; 368 369 // Optional module pass. 370 struct ModuleOpt: public ModulePass { 371 char run = 0; 372 static char ID; 373 ModuleOpt() : ModulePass(ID) { } 374 bool runOnModule(Module &M) override { 375 if (!skipModule(M)) 376 run++; 377 return false; 378 } 379 }; 380 char ModuleOpt::ID=0; 381 382 TEST(PassManager, CustomOptPassGate) { 383 LLVMContext Context0; 384 LLVMContext Context1; 385 LLVMContext Context2; 386 CustomOptPassGate SkipOptionalPasses(true); 387 CustomOptPassGate RunOptionalPasses(false); 388 389 Module M0("custom-opt-bisect", Context0); 390 Module M1("custom-opt-bisect", Context1); 391 Module M2("custom-opt-bisect2", Context2); 392 struct ModuleOpt *mOpt0 = new ModuleOpt(); 393 struct ModuleOpt *mOpt1 = new ModuleOpt(); 394 struct ModuleOpt *mOpt2 = new ModuleOpt(); 395 396 mOpt0->run = mOpt1->run = mOpt2->run = 0; 397 398 legacy::PassManager Passes0; 399 legacy::PassManager Passes1; 400 legacy::PassManager Passes2; 401 402 Passes0.add(mOpt0); 403 Passes1.add(mOpt1); 404 Passes2.add(mOpt2); 405 406 Context1.setOptPassGate(SkipOptionalPasses); 407 Context2.setOptPassGate(RunOptionalPasses); 408 409 Passes0.run(M0); 410 Passes1.run(M1); 411 Passes2.run(M2); 412 413 // By default optional passes are run. 414 EXPECT_EQ(1, mOpt0->run); 415 416 // The first context skips optional passes. 417 EXPECT_EQ(0, mOpt1->run); 418 419 // The second context runs optional passes. 420 EXPECT_EQ(1, mOpt2->run); 421 } 422 423 Module *makeLLVMModule(LLVMContext &Context) { 424 // Module Construction 425 Module *mod = new Module("test-mem", Context); 426 mod->setDataLayout("e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-" 427 "i64:64:64-f32:32:32-f64:64:64-v64:64:64-v128:128:128-" 428 "a:0:64-s:64:64-f80:128:128"); 429 mod->setTargetTriple("x86_64-unknown-linux-gnu"); 430 431 // Type Definitions 432 std::vector<Type*>FuncTy_0_args; 433 FunctionType *FuncTy_0 = FunctionType::get( 434 /*Result=*/IntegerType::get(Context, 32), 435 /*Params=*/FuncTy_0_args, 436 /*isVarArg=*/false); 437 438 std::vector<Type*>FuncTy_2_args; 439 FuncTy_2_args.push_back(IntegerType::get(Context, 1)); 440 FunctionType *FuncTy_2 = FunctionType::get( 441 /*Result=*/Type::getVoidTy(Context), 442 /*Params=*/FuncTy_2_args, 443 /*isVarArg=*/false); 444 445 // Function Declarations 446 447 Function* func_test1 = Function::Create( 448 /*Type=*/FuncTy_0, 449 /*Linkage=*/GlobalValue::ExternalLinkage, 450 /*Name=*/"test1", mod); 451 func_test1->setCallingConv(CallingConv::C); 452 AttributeList func_test1_PAL; 453 func_test1->setAttributes(func_test1_PAL); 454 455 Function* func_test2 = Function::Create( 456 /*Type=*/FuncTy_0, 457 /*Linkage=*/GlobalValue::ExternalLinkage, 458 /*Name=*/"test2", mod); 459 func_test2->setCallingConv(CallingConv::C); 460 AttributeList func_test2_PAL; 461 func_test2->setAttributes(func_test2_PAL); 462 463 Function* func_test3 = Function::Create( 464 /*Type=*/FuncTy_0, 465 /*Linkage=*/GlobalValue::InternalLinkage, 466 /*Name=*/"test3", mod); 467 func_test3->setCallingConv(CallingConv::C); 468 AttributeList func_test3_PAL; 469 func_test3->setAttributes(func_test3_PAL); 470 471 Function* func_test4 = Function::Create( 472 /*Type=*/FuncTy_2, 473 /*Linkage=*/GlobalValue::ExternalLinkage, 474 /*Name=*/"test4", mod); 475 func_test4->setCallingConv(CallingConv::C); 476 AttributeList func_test4_PAL; 477 func_test4->setAttributes(func_test4_PAL); 478 479 // Global Variable Declarations 480 481 482 // Constant Definitions 483 484 // Global Variable Definitions 485 486 // Function Definitions 487 488 // Function: test1 (func_test1) 489 { 490 491 BasicBlock *label_entry = 492 BasicBlock::Create(Context, "entry", func_test1, nullptr); 493 494 // Block entry (label_entry) 495 CallInst* int32_3 = CallInst::Create(func_test2, "", label_entry); 496 int32_3->setCallingConv(CallingConv::C); 497 int32_3->setTailCall(false); 498 AttributeList int32_3_PAL; 499 int32_3->setAttributes(int32_3_PAL); 500 501 ReturnInst::Create(Context, int32_3, label_entry); 502 } 503 504 // Function: test2 (func_test2) 505 { 506 507 BasicBlock *label_entry_5 = 508 BasicBlock::Create(Context, "entry", func_test2, nullptr); 509 510 // Block entry (label_entry_5) 511 CallInst* int32_6 = CallInst::Create(func_test3, "", label_entry_5); 512 int32_6->setCallingConv(CallingConv::C); 513 int32_6->setTailCall(false); 514 AttributeList int32_6_PAL; 515 int32_6->setAttributes(int32_6_PAL); 516 517 ReturnInst::Create(Context, int32_6, label_entry_5); 518 } 519 520 // Function: test3 (func_test3) 521 { 522 523 BasicBlock *label_entry_8 = 524 BasicBlock::Create(Context, "entry", func_test3, nullptr); 525 526 // Block entry (label_entry_8) 527 CallInst* int32_9 = CallInst::Create(func_test1, "", label_entry_8); 528 int32_9->setCallingConv(CallingConv::C); 529 int32_9->setTailCall(false); 530 AttributeList int32_9_PAL; 531 int32_9->setAttributes(int32_9_PAL); 532 533 ReturnInst::Create(Context, int32_9, label_entry_8); 534 } 535 536 // Function: test4 (func_test4) 537 { 538 Function::arg_iterator args = func_test4->arg_begin(); 539 Value *int1_f = &*args++; 540 int1_f->setName("f"); 541 542 BasicBlock *label_entry_11 = 543 BasicBlock::Create(Context, "entry", func_test4, nullptr); 544 BasicBlock *label_bb = 545 BasicBlock::Create(Context, "bb", func_test4, nullptr); 546 BasicBlock *label_bb1 = 547 BasicBlock::Create(Context, "bb1", func_test4, nullptr); 548 BasicBlock *label_return = 549 BasicBlock::Create(Context, "return", func_test4, nullptr); 550 551 // Block entry (label_entry_11) 552 auto *AI = new AllocaInst(func_test3->getType(), 0, "func3ptr", 553 label_entry_11); 554 new StoreInst(func_test3, AI, label_entry_11); 555 BranchInst::Create(label_bb, label_entry_11); 556 557 // Block bb (label_bb) 558 BranchInst::Create(label_bb, label_bb1, int1_f, label_bb); 559 560 // Block bb1 (label_bb1) 561 BranchInst::Create(label_bb1, label_return, int1_f, label_bb1); 562 563 // Block return (label_return) 564 ReturnInst::Create(Context, label_return); 565 } 566 return mod; 567 } 568 569 /// Split a simple function which contains only a call and a return into two 570 /// such that the first calls the second and the second whoever was called 571 /// initially. 572 Function *splitSimpleFunction(Function &F) { 573 LLVMContext &Context = F.getContext(); 574 Function *SF = Function::Create(F.getFunctionType(), F.getLinkage(), 575 F.getName() + "b", F.getParent()); 576 F.setName(F.getName() + "a"); 577 BasicBlock *Entry = BasicBlock::Create(Context, "entry", SF, nullptr); 578 CallInst &CI = cast<CallInst>(F.getEntryBlock().front()); 579 CI.clone()->insertBefore(ReturnInst::Create(Context, Entry)); 580 CI.setCalledFunction(SF); 581 return SF; 582 } 583 584 struct CGModifierPass : public CGPass { 585 unsigned NumSCCs = 0; 586 unsigned NumFns = 0; 587 unsigned NumFnDecls = 0; 588 unsigned SetupWorked = 0; 589 unsigned NumExtCalledBefore = 0; 590 unsigned NumExtCalledAfter = 0; 591 592 CallGraphUpdater CGU; 593 594 bool runOnSCC(CallGraphSCC &SCMM) override { 595 ++NumSCCs; 596 for (CallGraphNode *N : SCMM) { 597 if (N->getFunction()){ 598 ++NumFns; 599 NumFnDecls += N->getFunction()->isDeclaration(); 600 } 601 } 602 CGPass::run(); 603 604 CallGraph &CG = const_cast<CallGraph &>(SCMM.getCallGraph()); 605 CallGraphNode *ExtCallingNode = CG.getExternalCallingNode(); 606 NumExtCalledBefore = ExtCallingNode->size(); 607 608 if (SCMM.size() <= 1) 609 return false; 610 611 CallGraphNode *N = *(SCMM.begin()); 612 Function *F = N->getFunction(); 613 Module *M = F->getParent(); 614 Function *Test1F = M->getFunction("test1"); 615 Function *Test2aF = M->getFunction("test2a"); 616 Function *Test2bF = M->getFunction("test2b"); 617 Function *Test3F = M->getFunction("test3"); 618 619 auto InSCC = [&](Function *Fn) { 620 return llvm::any_of(SCMM, [Fn](CallGraphNode *CGN) { 621 return CGN->getFunction() == Fn; 622 }); 623 }; 624 625 if (!Test1F || !Test2aF || !Test2bF || !Test3F || !InSCC(Test1F) || 626 !InSCC(Test2aF) || !InSCC(Test2bF) || !InSCC(Test3F)) 627 return false; 628 629 CallInst *CI = dyn_cast<CallInst>(&Test1F->getEntryBlock().front()); 630 if (!CI || CI->getCalledFunction() != Test2aF) 631 return false; 632 633 SetupWorked += 1; 634 635 // Create a replica of test3 and just move the blocks there. 636 Function *Test3FRepl = Function::Create( 637 /*Type=*/Test3F->getFunctionType(), 638 /*Linkage=*/GlobalValue::InternalLinkage, 639 /*Name=*/"test3repl", Test3F->getParent()); 640 while (!Test3F->empty()) { 641 BasicBlock &BB = Test3F->front(); 642 BB.removeFromParent(); 643 BB.insertInto(Test3FRepl); 644 } 645 646 CGU.initialize(CG, SCMM); 647 648 // Replace test3 with the replica. This is legal as it is actually 649 // internal and the "capturing use" is not really capturing anything. 650 CGU.replaceFunctionWith(*Test3F, *Test3FRepl); 651 Test3F->replaceAllUsesWith(Test3FRepl); 652 653 // Rewrite the call in test1 to point to the replica of 3 not test2. 654 CI->setCalledFunction(Test3FRepl); 655 656 // Delete test2a and test2b and reanalyze 1 as we changed calls inside. 657 CGU.removeFunction(*Test2aF); 658 CGU.removeFunction(*Test2bF); 659 CGU.reanalyzeFunction(*Test1F); 660 661 return true; 662 } 663 664 bool doFinalization(CallGraph &CG) override { 665 CGU.finalize(); 666 // We removed test2 and replaced the internal test3. 667 NumExtCalledAfter = CG.getExternalCallingNode()->size(); 668 return true; 669 } 670 }; 671 672 TEST(PassManager, CallGraphUpdater0) { 673 // SCC#1: test1->test2a->test2b->test3->test1 674 // SCC#2: test4 675 // SCC#3: test3 (the empty function declaration as we replaced it with 676 // test3repl when we visited SCC#1) 677 // SCC#4: test2a->test2b (the empty function declarations as we deleted 678 // these functions when we visited SCC#1) 679 // SCC#5: indirect call node 680 681 LLVMContext Context; 682 std::unique_ptr<Module> M(makeLLVMModule(Context)); 683 ASSERT_EQ(M->getFunctionList().size(), 4U); 684 Function *F = M->getFunction("test2"); 685 Function *SF = splitSimpleFunction(*F); 686 CallInst::Create(F, "", &*SF->getEntryBlock().getFirstInsertionPt()); 687 ASSERT_EQ(M->getFunctionList().size(), 5U); 688 CGModifierPass *P = new CGModifierPass(); 689 legacy::PassManager Passes; 690 Passes.add(P); 691 Passes.run(*M); 692 ASSERT_EQ(P->SetupWorked, 1U); 693 ASSERT_EQ(P->NumSCCs, 4U); 694 ASSERT_EQ(P->NumFns, 6U); 695 ASSERT_EQ(P->NumFnDecls, 1U); 696 ASSERT_EQ(M->getFunctionList().size(), 3U); 697 ASSERT_EQ(P->NumExtCalledBefore, /* test1, 2a, 2b, 3, 4 */ 5U); 698 ASSERT_EQ(P->NumExtCalledAfter, /* test1, 3repl, 4 */ 3U); 699 } 700 701 // Test for call graph SCC pass that replaces all callback call instructions 702 // with clones and updates CallGraph by calling CallGraph::replaceCallEdge() 703 // method. Test is expected to complete successfully after running pass on 704 // all SCCs in the test module. 705 struct CallbackCallsModifierPass : public CGPass { 706 bool runOnSCC(CallGraphSCC &SCC) override { 707 CGPass::run(); 708 709 CallGraph &CG = const_cast<CallGraph &>(SCC.getCallGraph()); 710 711 bool Changed = false; 712 for (CallGraphNode *CGN : SCC) { 713 Function *F = CGN->getFunction(); 714 if (!F || F->isDeclaration()) 715 continue; 716 717 SmallVector<CallBase *, 4u> Calls; 718 for (Use &U : F->uses()) { 719 AbstractCallSite ACS(&U); 720 if (!ACS || !ACS.isCallbackCall() || !ACS.isCallee(&U)) 721 continue; 722 Calls.push_back(cast<CallBase>(ACS.getInstruction())); 723 } 724 if (Calls.empty()) 725 continue; 726 727 for (CallBase *OldCB : Calls) { 728 CallGraphNode *CallerCGN = CG[OldCB->getParent()->getParent()]; 729 assert(any_of(*CallerCGN, 730 [CGN](const CallGraphNode::CallRecord &CallRecord) { 731 return CallRecord.second == CGN; 732 }) && 733 "function is not a callee"); 734 735 CallBase *NewCB = cast<CallBase>(OldCB->clone()); 736 737 NewCB->insertBefore(OldCB); 738 NewCB->takeName(OldCB); 739 740 CallerCGN->replaceCallEdge(*OldCB, *NewCB, CG[F]); 741 742 OldCB->replaceAllUsesWith(NewCB); 743 OldCB->eraseFromParent(); 744 } 745 Changed = true; 746 } 747 return Changed; 748 } 749 }; 750 751 TEST(PassManager, CallbackCallsModifier0) { 752 LLVMContext Context; 753 754 const char *IR = "define void @foo() {\n" 755 " call void @broker(void (i8*)* @callback0, i8* null)\n" 756 " call void @broker(void (i8*)* @callback1, i8* null)\n" 757 " ret void\n" 758 "}\n" 759 "\n" 760 "declare !callback !0 void @broker(void (i8*)*, i8*)\n" 761 "\n" 762 "define internal void @callback0(i8* %arg) {\n" 763 " ret void\n" 764 "}\n" 765 "\n" 766 "define internal void @callback1(i8* %arg) {\n" 767 " ret void\n" 768 "}\n" 769 "\n" 770 "!0 = !{!1}\n" 771 "!1 = !{i64 0, i64 1, i1 false}"; 772 773 SMDiagnostic Err; 774 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, Context); 775 if (!M) 776 Err.print("LegacyPassManagerTest", errs()); 777 778 CallbackCallsModifierPass *P = new CallbackCallsModifierPass(); 779 legacy::PassManager Passes; 780 Passes.add(P); 781 Passes.run(*M); 782 } 783 } 784 } 785 786 INITIALIZE_PASS(ModuleNDM, "mndm", "mndm", false, false) 787 INITIALIZE_PASS_BEGIN(CGPass, "cgp","cgp", false, false) 788 INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) 789 INITIALIZE_PASS_END(CGPass, "cgp","cgp", false, false) 790 INITIALIZE_PASS(FPass, "fp","fp", false, false) 791 INITIALIZE_PASS_BEGIN(LPass, "lp","lp", false, false) 792 INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) 793 INITIALIZE_PASS_END(LPass, "lp","lp", false, false) 794