1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Utils/CodeExtractor.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/Analysis/AssumptionCache.h" 12 #include "llvm/IR/BasicBlock.h" 13 #include "llvm/IR/Dominators.h" 14 #include "llvm/IR/Instructions.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/Verifier.h" 18 #include "llvm/IRReader/IRReader.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "gtest/gtest.h" 21 22 using namespace llvm; 23 24 namespace { 25 BasicBlock *getBlockByName(Function *F, StringRef name) { 26 for (auto &BB : *F) 27 if (BB.getName() == name) 28 return &BB; 29 return nullptr; 30 } 31 32 TEST(CodeExtractor, ExitStub) { 33 LLVMContext Ctx; 34 SMDiagnostic Err; 35 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 36 define i32 @foo(i32 %x, i32 %y, i32 %z) { 37 header: 38 %0 = icmp ugt i32 %x, %y 39 br i1 %0, label %body1, label %body2 40 41 body1: 42 %1 = add i32 %z, 2 43 br label %notExtracted 44 45 body2: 46 %2 = mul i32 %z, 7 47 br label %notExtracted 48 49 notExtracted: 50 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 51 %4 = add i32 %3, %x 52 ret i32 %4 53 } 54 )invalid", 55 Err, Ctx)); 56 57 Function *Func = M->getFunction("foo"); 58 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"), 59 getBlockByName(Func, "body1"), 60 getBlockByName(Func, "body2") }; 61 62 CodeExtractor CE(Candidates); 63 EXPECT_TRUE(CE.isEligible()); 64 65 CodeExtractorAnalysisCache CEAC(*Func); 66 Function *Outlined = CE.extractCodeRegion(CEAC); 67 EXPECT_TRUE(Outlined); 68 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 69 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 70 // Ensure that PHI in exit block has only one incoming value (from code 71 // replacer block). 72 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 73 // Ensure that there is a PHI in outlined function with 2 incoming values. 74 EXPECT_TRUE(ExitSplit && 75 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 76 EXPECT_FALSE(verifyFunction(*Outlined)); 77 EXPECT_FALSE(verifyFunction(*Func)); 78 } 79 80 TEST(CodeExtractor, ExitPHIOnePredFromRegion) { 81 LLVMContext Ctx; 82 SMDiagnostic Err; 83 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 84 define i32 @foo() { 85 header: 86 br i1 undef, label %extracted1, label %pred 87 88 pred: 89 br i1 undef, label %exit1, label %exit2 90 91 extracted1: 92 br i1 undef, label %extracted2, label %exit1 93 94 extracted2: 95 br label %exit2 96 97 exit1: 98 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] 99 ret i32 %0 100 101 exit2: 102 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] 103 ret i32 %1 104 } 105 )invalid", Err, Ctx)); 106 107 Function *Func = M->getFunction("foo"); 108 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 109 getBlockByName(Func, "extracted1"), 110 getBlockByName(Func, "extracted2") 111 }; 112 113 CodeExtractor CE(ExtractedBlocks); 114 EXPECT_TRUE(CE.isEligible()); 115 116 CodeExtractorAnalysisCache CEAC(*Func); 117 Function *Outlined = CE.extractCodeRegion(CEAC); 118 EXPECT_TRUE(Outlined); 119 BasicBlock *Exit1 = getBlockByName(Func, "exit1"); 120 BasicBlock *Exit2 = getBlockByName(Func, "exit2"); 121 // Ensure that PHIs in exits are not splitted (since that they have only one 122 // incoming value from extracted region). 123 EXPECT_TRUE(Exit1 && 124 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2); 125 EXPECT_TRUE(Exit2 && 126 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2); 127 EXPECT_FALSE(verifyFunction(*Outlined)); 128 EXPECT_FALSE(verifyFunction(*Func)); 129 } 130 131 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) { 132 LLVMContext Ctx; 133 SMDiagnostic Err; 134 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 135 declare i8 @hoge() 136 137 define i32 @foo() personality i8* null { 138 entry: 139 %call = invoke i8 @hoge() 140 to label %invoke.cont unwind label %lpad 141 142 invoke.cont: ; preds = %entry 143 unreachable 144 145 lpad: ; preds = %entry 146 %0 = landingpad { i8*, i32 } 147 catch i8* null 148 br i1 undef, label %catch, label %finally.catchall 149 150 catch: ; preds = %lpad 151 %call2 = invoke i8 @hoge() 152 to label %invoke.cont2 unwind label %lpad2 153 154 invoke.cont2: ; preds = %catch 155 %call3 = invoke i8 @hoge() 156 to label %invoke.cont3 unwind label %lpad2 157 158 invoke.cont3: ; preds = %invoke.cont2 159 unreachable 160 161 lpad2: ; preds = %invoke.cont2, %catch 162 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ] 163 %1 = landingpad { i8*, i32 } 164 catch i8* null 165 br label %finally.catchall 166 167 finally.catchall: ; preds = %lpad33, %lpad 168 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ] 169 unreachable 170 } 171 )invalid", Err, Ctx)); 172 173 if (!M) { 174 Err.print("unit", errs()); 175 exit(1); 176 } 177 178 Function *Func = M->getFunction("foo"); 179 EXPECT_FALSE(verifyFunction(*Func, &errs())); 180 181 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 182 getBlockByName(Func, "catch"), 183 getBlockByName(Func, "invoke.cont2"), 184 getBlockByName(Func, "invoke.cont3"), 185 getBlockByName(Func, "lpad2") 186 }; 187 188 CodeExtractor CE(ExtractedBlocks); 189 EXPECT_TRUE(CE.isEligible()); 190 191 CodeExtractorAnalysisCache CEAC(*Func); 192 Function *Outlined = CE.extractCodeRegion(CEAC); 193 EXPECT_TRUE(Outlined); 194 EXPECT_FALSE(verifyFunction(*Outlined, &errs())); 195 EXPECT_FALSE(verifyFunction(*Func, &errs())); 196 } 197 198 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { 199 LLVMContext Ctx; 200 SMDiagnostic Err; 201 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 202 declare i32 @bar() 203 204 define i32 @foo() personality i8* null { 205 entry: 206 %0 = invoke i32 @bar() to label %exit unwind label %lpad 207 208 exit: 209 ret i32 %0 210 211 lpad: 212 %1 = landingpad { i8*, i32 } 213 cleanup 214 resume { i8*, i32 } %1 215 } 216 )invalid", 217 Err, Ctx)); 218 219 Function *Func = M->getFunction("foo"); 220 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"), 221 getBlockByName(Func, "lpad") }; 222 223 CodeExtractor CE(Blocks); 224 EXPECT_TRUE(CE.isEligible()); 225 226 CodeExtractorAnalysisCache CEAC(*Func); 227 Function *Outlined = CE.extractCodeRegion(CEAC); 228 EXPECT_TRUE(Outlined); 229 EXPECT_FALSE(verifyFunction(*Outlined)); 230 EXPECT_FALSE(verifyFunction(*Func)); 231 } 232 233 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) { 234 LLVMContext Ctx; 235 SMDiagnostic Err; 236 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 237 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" 238 target triple = "aarch64" 239 240 %b = type { i64 } 241 declare void @g(i8*) 242 243 declare void @llvm.assume(i1) #0 244 245 define void @test() { 246 entry: 247 br label %label 248 249 label: 250 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8 251 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0 252 %2 = load i64, i64* %1, align 8 253 %3 = icmp ugt i64 %2, 1 254 br i1 %3, label %if.then, label %if.else 255 256 if.then: 257 unreachable 258 259 if.else: 260 call void @g(i8* undef) 261 store i64 undef, i64* null, align 536870912 262 %4 = icmp eq i64 %2, 0 263 call void @llvm.assume(i1 %4) 264 unreachable 265 } 266 267 attributes #0 = { nounwind willreturn } 268 )ir", 269 Err, Ctx)); 270 271 assert(M && "Could not parse module?"); 272 Function *Func = M->getFunction("test"); 273 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") }; 274 AssumptionCache AC(*Func); 275 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC); 276 EXPECT_TRUE(CE.isEligible()); 277 278 CodeExtractorAnalysisCache CEAC(*Func); 279 Function *Outlined = CE.extractCodeRegion(CEAC); 280 EXPECT_TRUE(Outlined); 281 EXPECT_FALSE(verifyFunction(*Outlined)); 282 EXPECT_FALSE(verifyFunction(*Func)); 283 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC)); 284 } 285 286 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) { 287 LLVMContext Ctx; 288 SMDiagnostic Err; 289 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 290 target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128" 291 target triple = "x86_64-unknown-linux-gnu" 292 293 declare void @use(i32*) 294 declare void @llvm.lifetime.start.p0i8(i64, i8*) 295 declare void @llvm.lifetime.end.p0i8(i64, i8*) 296 297 define void @foo() { 298 entry: 299 %0 = alloca i32 300 br label %extract 301 302 extract: 303 %1 = bitcast i32* %0 to i8* 304 call void @llvm.lifetime.start.p0i8(i64 4, i8* %1) 305 call void @use(i32* %0) 306 br label %exit 307 308 exit: 309 call void @use(i32* %0) 310 call void @llvm.lifetime.end.p0i8(i64 4, i8* %1) 311 ret void 312 } 313 )ir", 314 Err, Ctx)); 315 316 Function *Func = M->getFunction("foo"); 317 SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")}; 318 319 CodeExtractor CE(Blocks); 320 EXPECT_TRUE(CE.isEligible()); 321 322 CodeExtractorAnalysisCache CEAC(*Func); 323 SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands; 324 BasicBlock *CommonExit = nullptr; 325 CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit); 326 CE.findInputsOutputs(Inputs, Outputs, SinkingCands); 327 EXPECT_EQ(Outputs.size(), 0U); 328 329 Function *Outlined = CE.extractCodeRegion(CEAC); 330 EXPECT_TRUE(Outlined); 331 EXPECT_FALSE(verifyFunction(*Outlined)); 332 EXPECT_FALSE(verifyFunction(*Func)); 333 } 334 } // end anonymous namespace 335