1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Utils/CodeExtractor.h" 10 #include "llvm/Analysis/AssumptionCache.h" 11 #include "llvm/AsmParser/Parser.h" 12 #include "llvm/IR/BasicBlock.h" 13 #include "llvm/IR/Constants.h" 14 #include "llvm/IR/Dominators.h" 15 #include "llvm/IR/InstIterator.h" 16 #include "llvm/IR/Instructions.h" 17 #include "llvm/IR/LLVMContext.h" 18 #include "llvm/IR/Module.h" 19 #include "llvm/IR/Verifier.h" 20 #include "llvm/IRReader/IRReader.h" 21 #include "llvm/Support/SourceMgr.h" 22 #include "gtest/gtest.h" 23 24 using namespace llvm; 25 26 namespace { 27 BasicBlock *getBlockByName(Function *F, StringRef name) { 28 for (auto &BB : *F) 29 if (BB.getName() == name) 30 return &BB; 31 return nullptr; 32 } 33 34 Instruction *getInstByName(Function *F, StringRef Name) { 35 for (Instruction &I : instructions(F)) 36 if (I.getName() == Name) 37 return &I; 38 return nullptr; 39 } 40 41 TEST(CodeExtractor, ExitStub) { 42 LLVMContext Ctx; 43 SMDiagnostic Err; 44 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 45 define i32 @foo(i32 %x, i32 %y, i32 %z) { 46 header: 47 %0 = icmp ugt i32 %x, %y 48 br i1 %0, label %body1, label %body2 49 50 body1: 51 %1 = add i32 %z, 2 52 br label %notExtracted 53 54 body2: 55 %2 = mul i32 %z, 7 56 br label %notExtracted 57 58 notExtracted: 59 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 60 %4 = add i32 %3, %x 61 ret i32 %4 62 } 63 )invalid", 64 Err, Ctx)); 65 66 Function *Func = M->getFunction("foo"); 67 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"), 68 getBlockByName(Func, "body1"), 69 getBlockByName(Func, "body2") }; 70 71 CodeExtractor CE(Candidates); 72 EXPECT_TRUE(CE.isEligible()); 73 74 CodeExtractorAnalysisCache CEAC(*Func); 75 Function *Outlined = CE.extractCodeRegion(CEAC); 76 EXPECT_TRUE(Outlined); 77 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 78 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 79 // Ensure that PHI in exit block has only one incoming value (from code 80 // replacer block). 81 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 82 // Ensure that there is a PHI in outlined function with 2 incoming values. 83 EXPECT_TRUE(ExitSplit && 84 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 85 EXPECT_FALSE(verifyFunction(*Outlined)); 86 EXPECT_FALSE(verifyFunction(*Func)); 87 } 88 89 TEST(CodeExtractor, InputOutputMonitoring) { 90 LLVMContext Ctx; 91 SMDiagnostic Err; 92 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 93 define i32 @foo(i32 %x, i32 %y, i32 %z) { 94 header: 95 %0 = icmp ugt i32 %x, %y 96 br i1 %0, label %body1, label %body2 97 98 body1: 99 %1 = add i32 %z, 2 100 br label %notExtracted 101 102 body2: 103 %2 = mul i32 %z, 7 104 br label %notExtracted 105 106 notExtracted: 107 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 108 %4 = add i32 %3, %x 109 ret i32 %4 110 } 111 )invalid", 112 Err, Ctx)); 113 114 Function *Func = M->getFunction("foo"); 115 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"), 116 getBlockByName(Func, "body1"), 117 getBlockByName(Func, "body2")}; 118 119 CodeExtractor CE(Candidates); 120 EXPECT_TRUE(CE.isEligible()); 121 122 CodeExtractorAnalysisCache CEAC(*Func); 123 SetVector<Value *> Inputs, Outputs; 124 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 125 EXPECT_TRUE(Outlined); 126 127 EXPECT_EQ(Inputs.size(), 3u); 128 EXPECT_EQ(Inputs[0], Func->getArg(2)); 129 EXPECT_EQ(Inputs[1], Func->getArg(0)); 130 EXPECT_EQ(Inputs[2], Func->getArg(1)); 131 EXPECT_EQ(Outputs.size(), 1u); 132 StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back()); 133 Value *OutputVal = SI->getValueOperand(); 134 EXPECT_EQ(Outputs[0], OutputVal); 135 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 136 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 137 // Ensure that PHI in exit block has only one incoming value (from code 138 // replacer block). 139 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 140 // Ensure that there is a PHI in outlined function with 2 incoming values. 141 EXPECT_TRUE(ExitSplit && 142 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 143 EXPECT_FALSE(verifyFunction(*Outlined)); 144 EXPECT_FALSE(verifyFunction(*Func)); 145 } 146 147 TEST(CodeExtractor, ExitBlockOrderingPhis) { 148 LLVMContext Ctx; 149 SMDiagnostic Err; 150 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 151 define void @foo(i32 %a, i32 %b) { 152 entry: 153 %0 = alloca i32, align 4 154 br label %test0 155 test0: 156 %c = load i32, i32* %0, align 4 157 br label %test1 158 test1: 159 %e = load i32, i32* %0, align 4 160 br i1 true, label %first, label %test 161 test: 162 %d = load i32, i32* %0, align 4 163 br i1 true, label %first, label %next 164 first: 165 %1 = phi i32 [ %c, %test ], [ %e, %test1 ] 166 ret void 167 next: 168 %2 = add i32 %d, 1 169 %3 = add i32 %e, 1 170 ret void 171 } 172 )invalid", 173 Err, Ctx)); 174 Function *Func = M->getFunction("foo"); 175 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"), 176 getBlockByName(Func, "test1"), 177 getBlockByName(Func, "test") }; 178 179 CodeExtractor CE(Candidates); 180 EXPECT_TRUE(CE.isEligible()); 181 182 CodeExtractorAnalysisCache CEAC(*Func); 183 Function *Outlined = CE.extractCodeRegion(CEAC); 184 EXPECT_TRUE(Outlined); 185 186 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); 187 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); 188 189 Instruction *FirstTerm = FirstExitStub->getTerminator(); 190 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm); 191 EXPECT_TRUE(FirstReturn); 192 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue()); 193 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); 194 195 Instruction *NextTerm = NextExitStub->getTerminator(); 196 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm); 197 EXPECT_TRUE(NextReturn); 198 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue()); 199 EXPECT_TRUE(CINext->getLimitedValue() == 0u); 200 201 EXPECT_FALSE(verifyFunction(*Outlined)); 202 EXPECT_FALSE(verifyFunction(*Func)); 203 } 204 205 TEST(CodeExtractor, ExitBlockOrdering) { 206 LLVMContext Ctx; 207 SMDiagnostic Err; 208 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 209 define void @foo(i32 %a, i32 %b) { 210 entry: 211 %0 = alloca i32, align 4 212 br label %test0 213 test0: 214 %c = load i32, i32* %0, align 4 215 br label %test1 216 test1: 217 %e = load i32, i32* %0, align 4 218 br i1 true, label %first, label %test 219 test: 220 %d = load i32, i32* %0, align 4 221 br i1 true, label %first, label %next 222 first: 223 ret void 224 next: 225 %1 = add i32 %d, 1 226 %2 = add i32 %e, 1 227 ret void 228 } 229 )invalid", 230 Err, Ctx)); 231 Function *Func = M->getFunction("foo"); 232 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"), 233 getBlockByName(Func, "test1"), 234 getBlockByName(Func, "test") }; 235 236 CodeExtractor CE(Candidates); 237 EXPECT_TRUE(CE.isEligible()); 238 239 CodeExtractorAnalysisCache CEAC(*Func); 240 Function *Outlined = CE.extractCodeRegion(CEAC); 241 EXPECT_TRUE(Outlined); 242 243 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); 244 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); 245 246 Instruction *FirstTerm = FirstExitStub->getTerminator(); 247 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm); 248 EXPECT_TRUE(FirstReturn); 249 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue()); 250 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); 251 252 Instruction *NextTerm = NextExitStub->getTerminator(); 253 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm); 254 EXPECT_TRUE(NextReturn); 255 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue()); 256 EXPECT_TRUE(CINext->getLimitedValue() == 0u); 257 258 EXPECT_FALSE(verifyFunction(*Outlined)); 259 EXPECT_FALSE(verifyFunction(*Func)); 260 } 261 262 TEST(CodeExtractor, ExitPHIOnePredFromRegion) { 263 LLVMContext Ctx; 264 SMDiagnostic Err; 265 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 266 define i32 @foo() { 267 header: 268 br i1 undef, label %extracted1, label %pred 269 270 pred: 271 br i1 undef, label %exit1, label %exit2 272 273 extracted1: 274 br i1 undef, label %extracted2, label %exit1 275 276 extracted2: 277 br label %exit2 278 279 exit1: 280 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] 281 ret i32 %0 282 283 exit2: 284 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] 285 ret i32 %1 286 } 287 )invalid", Err, Ctx)); 288 289 Function *Func = M->getFunction("foo"); 290 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 291 getBlockByName(Func, "extracted1"), 292 getBlockByName(Func, "extracted2") 293 }; 294 295 CodeExtractor CE(ExtractedBlocks); 296 EXPECT_TRUE(CE.isEligible()); 297 298 CodeExtractorAnalysisCache CEAC(*Func); 299 Function *Outlined = CE.extractCodeRegion(CEAC); 300 EXPECT_TRUE(Outlined); 301 BasicBlock *Exit1 = getBlockByName(Func, "exit1"); 302 BasicBlock *Exit2 = getBlockByName(Func, "exit2"); 303 // Ensure that PHIs in exits are not splitted (since that they have only one 304 // incoming value from extracted region). 305 EXPECT_TRUE(Exit1 && 306 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2); 307 EXPECT_TRUE(Exit2 && 308 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2); 309 EXPECT_FALSE(verifyFunction(*Outlined)); 310 EXPECT_FALSE(verifyFunction(*Func)); 311 } 312 313 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) { 314 LLVMContext Ctx; 315 SMDiagnostic Err; 316 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 317 declare i8 @hoge() 318 319 define i32 @foo() personality i8* null { 320 entry: 321 %call = invoke i8 @hoge() 322 to label %invoke.cont unwind label %lpad 323 324 invoke.cont: ; preds = %entry 325 unreachable 326 327 lpad: ; preds = %entry 328 %0 = landingpad { i8*, i32 } 329 catch i8* null 330 br i1 undef, label %catch, label %finally.catchall 331 332 catch: ; preds = %lpad 333 %call2 = invoke i8 @hoge() 334 to label %invoke.cont2 unwind label %lpad2 335 336 invoke.cont2: ; preds = %catch 337 %call3 = invoke i8 @hoge() 338 to label %invoke.cont3 unwind label %lpad2 339 340 invoke.cont3: ; preds = %invoke.cont2 341 unreachable 342 343 lpad2: ; preds = %invoke.cont2, %catch 344 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ] 345 %1 = landingpad { i8*, i32 } 346 catch i8* null 347 br label %finally.catchall 348 349 finally.catchall: ; preds = %lpad33, %lpad 350 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ] 351 unreachable 352 } 353 )invalid", Err, Ctx)); 354 355 if (!M) { 356 Err.print("unit", errs()); 357 exit(1); 358 } 359 360 Function *Func = M->getFunction("foo"); 361 EXPECT_FALSE(verifyFunction(*Func, &errs())); 362 363 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 364 getBlockByName(Func, "catch"), 365 getBlockByName(Func, "invoke.cont2"), 366 getBlockByName(Func, "invoke.cont3"), 367 getBlockByName(Func, "lpad2") 368 }; 369 370 CodeExtractor CE(ExtractedBlocks); 371 EXPECT_TRUE(CE.isEligible()); 372 373 CodeExtractorAnalysisCache CEAC(*Func); 374 Function *Outlined = CE.extractCodeRegion(CEAC); 375 EXPECT_TRUE(Outlined); 376 EXPECT_FALSE(verifyFunction(*Outlined, &errs())); 377 EXPECT_FALSE(verifyFunction(*Func, &errs())); 378 } 379 380 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { 381 LLVMContext Ctx; 382 SMDiagnostic Err; 383 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 384 declare i32 @bar() 385 386 define i32 @foo() personality i8* null { 387 entry: 388 %0 = invoke i32 @bar() to label %exit unwind label %lpad 389 390 exit: 391 ret i32 %0 392 393 lpad: 394 %1 = landingpad { i8*, i32 } 395 cleanup 396 resume { i8*, i32 } %1 397 } 398 )invalid", 399 Err, Ctx)); 400 401 Function *Func = M->getFunction("foo"); 402 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"), 403 getBlockByName(Func, "lpad") }; 404 405 CodeExtractor CE(Blocks); 406 EXPECT_TRUE(CE.isEligible()); 407 408 CodeExtractorAnalysisCache CEAC(*Func); 409 Function *Outlined = CE.extractCodeRegion(CEAC); 410 EXPECT_TRUE(Outlined); 411 EXPECT_FALSE(verifyFunction(*Outlined)); 412 EXPECT_FALSE(verifyFunction(*Func)); 413 } 414 415 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) { 416 LLVMContext Ctx; 417 SMDiagnostic Err; 418 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 419 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" 420 target triple = "aarch64" 421 422 %b = type { i64 } 423 declare void @g(i8*) 424 425 declare void @llvm.assume(i1) #0 426 427 define void @test() { 428 entry: 429 br label %label 430 431 label: 432 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8 433 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0 434 %2 = load i64, i64* %1, align 8 435 %3 = icmp ugt i64 %2, 1 436 br i1 %3, label %if.then, label %if.else 437 438 if.then: 439 unreachable 440 441 if.else: 442 call void @g(i8* undef) 443 store i64 undef, i64* null, align 536870912 444 %4 = icmp eq i64 %2, 0 445 call void @llvm.assume(i1 %4) 446 unreachable 447 } 448 449 attributes #0 = { nounwind willreturn } 450 )ir", 451 Err, Ctx)); 452 453 assert(M && "Could not parse module?"); 454 Function *Func = M->getFunction("test"); 455 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") }; 456 AssumptionCache AC(*Func); 457 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC); 458 EXPECT_TRUE(CE.isEligible()); 459 460 CodeExtractorAnalysisCache CEAC(*Func); 461 Function *Outlined = CE.extractCodeRegion(CEAC); 462 EXPECT_TRUE(Outlined); 463 EXPECT_FALSE(verifyFunction(*Outlined)); 464 EXPECT_FALSE(verifyFunction(*Func)); 465 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC)); 466 } 467 468 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) { 469 LLVMContext Ctx; 470 SMDiagnostic Err; 471 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 472 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" 473 target triple = "x86_64-unknown-linux-gnu" 474 475 declare void @use(i32*) 476 declare void @llvm.lifetime.start.p0i8(i64, i8*) 477 declare void @llvm.lifetime.end.p0i8(i64, i8*) 478 479 define void @foo() { 480 entry: 481 %0 = alloca i32 482 br label %extract 483 484 extract: 485 %1 = bitcast i32* %0 to i8* 486 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1) 487 call void @use(i32* %0) 488 br label %exit 489 490 exit: 491 call void @use(i32* %0) 492 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1) 493 ret void 494 } 495 )ir", 496 Err, Ctx)); 497 498 Function *Func = M->getFunction("foo"); 499 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 500 501 CodeExtractor CE(Blocks); 502 EXPECT_TRUE(CE.isEligible()); 503 504 CodeExtractorAnalysisCache CEAC(*Func); 505 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 506 BasicBlock *CommonExit = nullptr; 507 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 508 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 509 EXPECT_EQ(Outputs.size(), 0U); 510 511 Function *Outlined = CE.extractCodeRegion(CEAC); 512 EXPECT_TRUE(Outlined); 513 EXPECT_FALSE(verifyFunction(*Outlined)); 514 EXPECT_FALSE(verifyFunction(*Func)); 515 } 516 517 TEST(CodeExtractor, PartialAggregateArgs) { 518 LLVMContext Ctx; 519 SMDiagnostic Err; 520 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 521 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" 522 target triple = "x86_64-unknown-linux-gnu" 523 524 ; use different types such that an index mismatch will result in a type mismatch during verification. 525 declare void @use16(i16) 526 declare void @use32(i32) 527 declare void @use64(i64) 528 529 define void @foo(i16 %a, i32 %b, i64 %c) { 530 entry: 531 br label %extract 532 533 extract: 534 call void @use16(i16 %a) 535 call void @use32(i32 %b) 536 call void @use64(i64 %c) 537 %d = add i16 21, 21 538 %e = add i32 21, 21 539 %f = add i64 21, 21 540 br label %exit 541 542 exit: 543 call void @use16(i16 %d) 544 call void @use32(i32 %e) 545 call void @use64(i64 %f) 546 ret void 547 } 548 )ir", 549 Err, Ctx)); 550 551 Function *Func = M->getFunction("foo"); 552 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 553 554 // Create the CodeExtractor with arguments aggregation enabled. 555 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr, 556 /* AggregateArgs */ true); 557 EXPECT_TRUE(CE.isEligible()); 558 559 CodeExtractorAnalysisCache CEAC(*Func); 560 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 561 BasicBlock *CommonExit = nullptr; 562 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 563 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 564 // Exclude the middle input and output from the argument aggregate. 565 CE.excludeArgFromAggregate(Inputs[1]); 566 CE.excludeArgFromAggregate(Outputs[1]); 567 568 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 569 EXPECT_TRUE(Outlined); 570 // Expect 3 arguments in the outlined function: the excluded input, the 571 // excluded output, and the struct aggregate for the remaining inputs. 572 EXPECT_EQ(Outlined->arg_size(), 3U); 573 EXPECT_FALSE(verifyFunction(*Outlined)); 574 EXPECT_FALSE(verifyFunction(*Func)); 575 } 576 577 TEST(CodeExtractor, AllocaBlock) { 578 LLVMContext Ctx; 579 SMDiagnostic Err; 580 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 581 define i32 @foo(i32 %x, i32 %y, i32 %z) { 582 entry: 583 br label %allocas 584 585 allocas: 586 br label %body 587 588 body: 589 %w = add i32 %x, %y 590 br label %notExtracted 591 592 notExtracted: 593 %r = add i32 %w, %x 594 ret i32 %r 595 } 596 )invalid", 597 Err, Ctx)); 598 599 Function *Func = M->getFunction("foo"); 600 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "body")}; 601 602 BasicBlock *AllocaBlock = getBlockByName(Func, "allocas"); 603 CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false, 604 false, AllocaBlock); 605 CE.excludeArgFromAggregate(Func->getArg(0)); 606 CE.excludeArgFromAggregate(getInstByName(Func, "w")); 607 EXPECT_TRUE(CE.isEligible()); 608 609 CodeExtractorAnalysisCache CEAC(*Func); 610 SetVector<Value *> Inputs, Outputs; 611 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 612 EXPECT_TRUE(Outlined); 613 EXPECT_FALSE(verifyFunction(*Outlined)); 614 EXPECT_FALSE(verifyFunction(*Func)); 615 616 // The only added allocas may be in the dedicated alloca block. There should 617 // be one alloca for the struct, and another one for the reload value. 618 int NumAllocas = 0; 619 for (Instruction &I : instructions(Func)) { 620 if (!isa<AllocaInst>(I)) 621 continue; 622 EXPECT_EQ(I.getParent(), AllocaBlock); 623 NumAllocas += 1; 624 } 625 EXPECT_EQ(NumAllocas, 2); 626 } 627 628 /// Regression test to ensure we don't crash trying to set the name of the ptr 629 /// argument 630 TEST(CodeExtractor, PartialAggregateArgs2) { 631 LLVMContext Ctx; 632 SMDiagnostic Err; 633 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 634 declare void @usei(i32) 635 declare void @usep(ptr) 636 637 define void @foo(i32 %a, i32 %b, ptr %p) { 638 entry: 639 br label %extract 640 641 extract: 642 call void @usei(i32 %a) 643 call void @usei(i32 %b) 644 call void @usep(ptr %p) 645 br label %exit 646 647 exit: 648 ret void 649 } 650 )ir", 651 Err, Ctx)); 652 653 Function *Func = M->getFunction("foo"); 654 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 655 656 // Create the CodeExtractor with arguments aggregation enabled. 657 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr, 658 /* AggregateArgs */ true); 659 EXPECT_TRUE(CE.isEligible()); 660 661 CodeExtractorAnalysisCache CEAC(*Func); 662 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 663 BasicBlock *CommonExit = nullptr; 664 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 665 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 666 // Exclude the last input from the argument aggregate. 667 CE.excludeArgFromAggregate(Inputs[2]); 668 669 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 670 EXPECT_TRUE(Outlined); 671 EXPECT_FALSE(verifyFunction(*Outlined)); 672 EXPECT_FALSE(verifyFunction(*Func)); 673 } 674 675 TEST(CodeExtractor, OpenMPAggregateArgs) { 676 LLVMContext Ctx; 677 SMDiagnostic Err; 678 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 679 target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9" 680 target triple = "amdgcn-amd-amdhsa" 681 682 define void @foo(ptr %0) { 683 %2= alloca ptr, align 8, addrspace(5) 684 %3 = addrspacecast ptr addrspace(5) %2 to ptr 685 store ptr %0, ptr %3, align 8 686 %4 = load ptr, ptr %3, align 8 687 br label %entry 688 689 entry: 690 br label %extract 691 692 extract: 693 store i64 10, ptr %4, align 4 694 br label %exit 695 696 exit: 697 ret void 698 } 699 )ir", 700 Err, Ctx)); 701 Function *Func = M->getFunction("foo"); 702 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 703 704 // Create the CodeExtractor with arguments aggregation enabled. 705 // Outlined function argument should be declared in 0 address space 706 // even if the default alloca address space is 5. 707 CodeExtractor CE(Blocks, /* DominatorTree */ nullptr, 708 /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr, 709 /* BranchProbabilityInfo */ nullptr, 710 /* AssumptionCache */ nullptr, 711 /* AllowVarArgs */ true, 712 /* AllowAlloca */ true, 713 /* AllocaBlock*/ &Func->getEntryBlock(), 714 /* Suffix */ ".outlined", 715 /* ArgsInZeroAddressSpace */ true); 716 717 EXPECT_TRUE(CE.isEligible()); 718 719 CodeExtractorAnalysisCache CEAC(*Func); 720 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 721 BasicBlock *CommonExit = nullptr; 722 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 723 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 724 725 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 726 EXPECT_TRUE(Outlined); 727 EXPECT_EQ(Outlined->arg_size(), 1U); 728 // Check address space of outlined argument is ptr in address space 0 729 EXPECT_EQ(Outlined->getArg(0)->getType(), 730 PointerType::get(M->getContext(), 0)); 731 EXPECT_FALSE(verifyFunction(*Outlined)); 732 EXPECT_FALSE(verifyFunction(*Func)); 733 } 734 } // end anonymous namespace 735