1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Utils/CodeExtractor.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/Analysis/AssumptionCache.h" 12 #include "llvm/IR/BasicBlock.h" 13 #include "llvm/IR/Dominators.h" 14 #include "llvm/IR/Instructions.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/Verifier.h" 18 #include "llvm/IRReader/IRReader.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "gtest/gtest.h" 21 22 using namespace llvm; 23 24 namespace { 25 BasicBlock *getBlockByName(Function *F, StringRef name) { 26 for (auto &BB : *F) 27 if (BB.getName() == name) 28 return &BB; 29 return nullptr; 30 } 31 32 TEST(CodeExtractor, ExitStub) { 33 LLVMContext Ctx; 34 SMDiagnostic Err; 35 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 36 define i32 @foo(i32 %x, i32 %y, i32 %z) { 37 header: 38 %0 = icmp ugt i32 %x, %y 39 br i1 %0, label %body1, label %body2 40 41 body1: 42 %1 = add i32 %z, 2 43 br label %notExtracted 44 45 body2: 46 %2 = mul i32 %z, 7 47 br label %notExtracted 48 49 notExtracted: 50 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 51 %4 = add i32 %3, %x 52 ret i32 %4 53 } 54 )invalid", 55 Err, Ctx)); 56 57 Function *Func = M->getFunction("foo"); 58 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"), 59 getBlockByName(Func, "body1"), 60 getBlockByName(Func, "body2") }; 61 62 CodeExtractor CE(Candidates); 63 EXPECT_TRUE(CE.isEligible()); 64 65 CodeExtractorAnalysisCache CEAC(*Func); 66 Function *Outlined = CE.extractCodeRegion(CEAC); 67 EXPECT_TRUE(Outlined); 68 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 69 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 70 // Ensure that PHI in exit block has only one incoming value (from code 71 // replacer block). 72 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 73 // Ensure that there is a PHI in outlined function with 2 incoming values. 74 EXPECT_TRUE(ExitSplit && 75 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 76 EXPECT_FALSE(verifyFunction(*Outlined)); 77 EXPECT_FALSE(verifyFunction(*Func)); 78 } 79 80 TEST(CodeExtractor, InputOutputMonitoring) { 81 LLVMContext Ctx; 82 SMDiagnostic Err; 83 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 84 define i32 @foo(i32 %x, i32 %y, i32 %z) { 85 header: 86 %0 = icmp ugt i32 %x, %y 87 br i1 %0, label %body1, label %body2 88 89 body1: 90 %1 = add i32 %z, 2 91 br label %notExtracted 92 93 body2: 94 %2 = mul i32 %z, 7 95 br label %notExtracted 96 97 notExtracted: 98 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 99 %4 = add i32 %3, %x 100 ret i32 %4 101 } 102 )invalid", 103 Err, Ctx)); 104 105 Function *Func = M->getFunction("foo"); 106 SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"), 107 getBlockByName(Func, "body1"), 108 getBlockByName(Func, "body2")}; 109 110 CodeExtractor CE(Candidates); 111 EXPECT_TRUE(CE.isEligible()); 112 113 CodeExtractorAnalysisCache CEAC(*Func); 114 SetVector<Value *> Inputs, Outputs; 115 Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs); 116 EXPECT_TRUE(Outlined); 117 118 EXPECT_EQ(Inputs.size(), 3u); 119 EXPECT_EQ(Inputs[0], Func->getArg(2)); 120 EXPECT_EQ(Inputs[1], Func->getArg(0)); 121 EXPECT_EQ(Inputs[2], Func->getArg(1)); 122 EXPECT_EQ(Outputs.size(), 1u); 123 StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back()); 124 Value *OutputVal = SI->getValueOperand(); 125 EXPECT_EQ(Outputs[0], OutputVal); 126 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 127 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 128 // Ensure that PHI in exit block has only one incoming value (from code 129 // replacer block). 130 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 131 // Ensure that there is a PHI in outlined function with 2 incoming values. 132 EXPECT_TRUE(ExitSplit && 133 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 134 EXPECT_FALSE(verifyFunction(*Outlined)); 135 EXPECT_FALSE(verifyFunction(*Func)); 136 } 137 138 TEST(CodeExtractor, ExitBlockOrderingPhis) { 139 LLVMContext Ctx; 140 SMDiagnostic Err; 141 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 142 define void @foo(i32 %a, i32 %b) { 143 entry: 144 %0 = alloca i32, align 4 145 br label %test0 146 test0: 147 %c = load i32, i32* %0, align 4 148 br label %test1 149 test1: 150 %e = load i32, i32* %0, align 4 151 br i1 true, label %first, label %test 152 test: 153 %d = load i32, i32* %0, align 4 154 br i1 true, label %first, label %next 155 first: 156 %1 = phi i32 [ %c, %test ], [ %e, %test1 ] 157 ret void 158 next: 159 %2 = add i32 %d, 1 160 %3 = add i32 %e, 1 161 ret void 162 } 163 )invalid", 164 Err, Ctx)); 165 Function *Func = M->getFunction("foo"); 166 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"), 167 getBlockByName(Func, "test1"), 168 getBlockByName(Func, "test") }; 169 170 CodeExtractor CE(Candidates); 171 EXPECT_TRUE(CE.isEligible()); 172 173 CodeExtractorAnalysisCache CEAC(*Func); 174 Function *Outlined = CE.extractCodeRegion(CEAC); 175 EXPECT_TRUE(Outlined); 176 177 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); 178 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); 179 180 Instruction *FirstTerm = FirstExitStub->getTerminator(); 181 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm); 182 EXPECT_TRUE(FirstReturn); 183 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue()); 184 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); 185 186 Instruction *NextTerm = NextExitStub->getTerminator(); 187 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm); 188 EXPECT_TRUE(NextReturn); 189 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue()); 190 EXPECT_TRUE(CINext->getLimitedValue() == 0u); 191 192 EXPECT_FALSE(verifyFunction(*Outlined)); 193 EXPECT_FALSE(verifyFunction(*Func)); 194 } 195 196 TEST(CodeExtractor, ExitBlockOrdering) { 197 LLVMContext Ctx; 198 SMDiagnostic Err; 199 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 200 define void @foo(i32 %a, i32 %b) { 201 entry: 202 %0 = alloca i32, align 4 203 br label %test0 204 test0: 205 %c = load i32, i32* %0, align 4 206 br label %test1 207 test1: 208 %e = load i32, i32* %0, align 4 209 br i1 true, label %first, label %test 210 test: 211 %d = load i32, i32* %0, align 4 212 br i1 true, label %first, label %next 213 first: 214 ret void 215 next: 216 %1 = add i32 %d, 1 217 %2 = add i32 %e, 1 218 ret void 219 } 220 )invalid", 221 Err, Ctx)); 222 Function *Func = M->getFunction("foo"); 223 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"), 224 getBlockByName(Func, "test1"), 225 getBlockByName(Func, "test") }; 226 227 CodeExtractor CE(Candidates); 228 EXPECT_TRUE(CE.isEligible()); 229 230 CodeExtractorAnalysisCache CEAC(*Func); 231 Function *Outlined = CE.extractCodeRegion(CEAC); 232 EXPECT_TRUE(Outlined); 233 234 BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub"); 235 BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub"); 236 237 Instruction *FirstTerm = FirstExitStub->getTerminator(); 238 ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm); 239 EXPECT_TRUE(FirstReturn); 240 ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue()); 241 EXPECT_TRUE(CIFirst->getLimitedValue() == 1u); 242 243 Instruction *NextTerm = NextExitStub->getTerminator(); 244 ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm); 245 EXPECT_TRUE(NextReturn); 246 ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue()); 247 EXPECT_TRUE(CINext->getLimitedValue() == 0u); 248 249 EXPECT_FALSE(verifyFunction(*Outlined)); 250 EXPECT_FALSE(verifyFunction(*Func)); 251 } 252 253 TEST(CodeExtractor, ExitPHIOnePredFromRegion) { 254 LLVMContext Ctx; 255 SMDiagnostic Err; 256 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 257 define i32 @foo() { 258 header: 259 br i1 undef, label %extracted1, label %pred 260 261 pred: 262 br i1 undef, label %exit1, label %exit2 263 264 extracted1: 265 br i1 undef, label %extracted2, label %exit1 266 267 extracted2: 268 br label %exit2 269 270 exit1: 271 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] 272 ret i32 %0 273 274 exit2: 275 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] 276 ret i32 %1 277 } 278 )invalid", Err, Ctx)); 279 280 Function *Func = M->getFunction("foo"); 281 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 282 getBlockByName(Func, "extracted1"), 283 getBlockByName(Func, "extracted2") 284 }; 285 286 CodeExtractor CE(ExtractedBlocks); 287 EXPECT_TRUE(CE.isEligible()); 288 289 CodeExtractorAnalysisCache CEAC(*Func); 290 Function *Outlined = CE.extractCodeRegion(CEAC); 291 EXPECT_TRUE(Outlined); 292 BasicBlock *Exit1 = getBlockByName(Func, "exit1"); 293 BasicBlock *Exit2 = getBlockByName(Func, "exit2"); 294 // Ensure that PHIs in exits are not splitted (since that they have only one 295 // incoming value from extracted region). 296 EXPECT_TRUE(Exit1 && 297 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2); 298 EXPECT_TRUE(Exit2 && 299 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2); 300 EXPECT_FALSE(verifyFunction(*Outlined)); 301 EXPECT_FALSE(verifyFunction(*Func)); 302 } 303 304 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) { 305 LLVMContext Ctx; 306 SMDiagnostic Err; 307 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 308 declare i8 @hoge() 309 310 define i32 @foo() personality i8* null { 311 entry: 312 %call = invoke i8 @hoge() 313 to label %invoke.cont unwind label %lpad 314 315 invoke.cont: ; preds = %entry 316 unreachable 317 318 lpad: ; preds = %entry 319 %0 = landingpad { i8*, i32 } 320 catch i8* null 321 br i1 undef, label %catch, label %finally.catchall 322 323 catch: ; preds = %lpad 324 %call2 = invoke i8 @hoge() 325 to label %invoke.cont2 unwind label %lpad2 326 327 invoke.cont2: ; preds = %catch 328 %call3 = invoke i8 @hoge() 329 to label %invoke.cont3 unwind label %lpad2 330 331 invoke.cont3: ; preds = %invoke.cont2 332 unreachable 333 334 lpad2: ; preds = %invoke.cont2, %catch 335 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ] 336 %1 = landingpad { i8*, i32 } 337 catch i8* null 338 br label %finally.catchall 339 340 finally.catchall: ; preds = %lpad33, %lpad 341 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ] 342 unreachable 343 } 344 )invalid", Err, Ctx)); 345 346 if (!M) { 347 Err.print("unit", errs()); 348 exit(1); 349 } 350 351 Function *Func = M->getFunction("foo"); 352 EXPECT_FALSE(verifyFunction(*Func, &errs())); 353 354 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 355 getBlockByName(Func, "catch"), 356 getBlockByName(Func, "invoke.cont2"), 357 getBlockByName(Func, "invoke.cont3"), 358 getBlockByName(Func, "lpad2") 359 }; 360 361 CodeExtractor CE(ExtractedBlocks); 362 EXPECT_TRUE(CE.isEligible()); 363 364 CodeExtractorAnalysisCache CEAC(*Func); 365 Function *Outlined = CE.extractCodeRegion(CEAC); 366 EXPECT_TRUE(Outlined); 367 EXPECT_FALSE(verifyFunction(*Outlined, &errs())); 368 EXPECT_FALSE(verifyFunction(*Func, &errs())); 369 } 370 371 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { 372 LLVMContext Ctx; 373 SMDiagnostic Err; 374 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 375 declare i32 @bar() 376 377 define i32 @foo() personality i8* null { 378 entry: 379 %0 = invoke i32 @bar() to label %exit unwind label %lpad 380 381 exit: 382 ret i32 %0 383 384 lpad: 385 %1 = landingpad { i8*, i32 } 386 cleanup 387 resume { i8*, i32 } %1 388 } 389 )invalid", 390 Err, Ctx)); 391 392 Function *Func = M->getFunction("foo"); 393 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"), 394 getBlockByName(Func, "lpad") }; 395 396 CodeExtractor CE(Blocks); 397 EXPECT_TRUE(CE.isEligible()); 398 399 CodeExtractorAnalysisCache CEAC(*Func); 400 Function *Outlined = CE.extractCodeRegion(CEAC); 401 EXPECT_TRUE(Outlined); 402 EXPECT_FALSE(verifyFunction(*Outlined)); 403 EXPECT_FALSE(verifyFunction(*Func)); 404 } 405 406 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) { 407 LLVMContext Ctx; 408 SMDiagnostic Err; 409 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 410 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" 411 target triple = "aarch64" 412 413 %b = type { i64 } 414 declare void @g(i8*) 415 416 declare void @llvm.assume(i1) #0 417 418 define void @test() { 419 entry: 420 br label %label 421 422 label: 423 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8 424 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0 425 %2 = load i64, i64* %1, align 8 426 %3 = icmp ugt i64 %2, 1 427 br i1 %3, label %if.then, label %if.else 428 429 if.then: 430 unreachable 431 432 if.else: 433 call void @g(i8* undef) 434 store i64 undef, i64* null, align 536870912 435 %4 = icmp eq i64 %2, 0 436 call void @llvm.assume(i1 %4) 437 unreachable 438 } 439 440 attributes #0 = { nounwind willreturn } 441 )ir", 442 Err, Ctx)); 443 444 assert(M && "Could not parse module?"); 445 Function *Func = M->getFunction("test"); 446 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") }; 447 AssumptionCache AC(*Func); 448 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC); 449 EXPECT_TRUE(CE.isEligible()); 450 451 CodeExtractorAnalysisCache CEAC(*Func); 452 Function *Outlined = CE.extractCodeRegion(CEAC); 453 EXPECT_TRUE(Outlined); 454 EXPECT_FALSE(verifyFunction(*Outlined)); 455 EXPECT_FALSE(verifyFunction(*Func)); 456 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC)); 457 } 458 459 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) { 460 LLVMContext Ctx; 461 SMDiagnostic Err; 462 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 463 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" 464 target triple = "x86_64-unknown-linux-gnu" 465 466 declare void @use(i32*) 467 declare void @llvm.lifetime.start.p0i8(i64, i8*) 468 declare void @llvm.lifetime.end.p0i8(i64, i8*) 469 470 define void @foo() { 471 entry: 472 %0 = alloca i32 473 br label %extract 474 475 extract: 476 %1 = bitcast i32* %0 to i8* 477 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1) 478 call void @use(i32* %0) 479 br label %exit 480 481 exit: 482 call void @use(i32* %0) 483 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1) 484 ret void 485 } 486 )ir", 487 Err, Ctx)); 488 489 Function *Func = M->getFunction("foo"); 490 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 491 492 CodeExtractor CE(Blocks); 493 EXPECT_TRUE(CE.isEligible()); 494 495 CodeExtractorAnalysisCache CEAC(*Func); 496 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 497 BasicBlock *CommonExit = nullptr; 498 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 499 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 500 EXPECT_EQ(Outputs.size(), 0U); 501 502 Function *Outlined = CE.extractCodeRegion(CEAC); 503 EXPECT_TRUE(Outlined); 504 EXPECT_FALSE(verifyFunction(*Outlined)); 505 EXPECT_FALSE(verifyFunction(*Func)); 506 } 507 } // end anonymous namespace 508