1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Utils/CodeExtractor.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/Analysis/AssumptionCache.h" 12 #include "llvm/IR/BasicBlock.h" 13 #include "llvm/IR/Dominators.h" 14 #include "llvm/IR/Instructions.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/IR/Verifier.h" 18 #include "llvm/IRReader/IRReader.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "gtest/gtest.h" 21 22 using namespace llvm; 23 24 namespace { 25 BasicBlock *getBlockByName(Function *F, StringRef name) { 26 for (auto &BB : *F) 27 if (BB.getName() == name) 28 return &BB; 29 return nullptr; 30 } 31 32 TEST(CodeExtractor, ExitStub) { 33 LLVMContext Ctx; 34 SMDiagnostic Err; 35 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 36 define i32 @foo(i32 %x, i32 %y, i32 %z) { 37 header: 38 %0 = icmp ugt i32 %x, %y 39 br i1 %0, label %body1, label %body2 40 41 body1: 42 %1 = add i32 %z, 2 43 br label %notExtracted 44 45 body2: 46 %2 = mul i32 %z, 7 47 br label %notExtracted 48 49 notExtracted: 50 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 51 %4 = add i32 %3, %x 52 ret i32 %4 53 } 54 )invalid", 55 Err, Ctx)); 56 57 Function *Func = M->getFunction("foo"); 58 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"), 59 getBlockByName(Func, "body1"), 60 getBlockByName(Func, "body2") }; 61 62 CodeExtractor CE(Candidates); 63 EXPECT_TRUE(CE.isEligible()); 64 65 Function *Outlined = CE.extractCodeRegion(); 66 EXPECT_TRUE(Outlined); 67 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 68 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 69 // Ensure that PHI in exit block has only one incoming value (from code 70 // replacer block). 71 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 72 // Ensure that there is a PHI in outlined function with 2 incoming values. 73 EXPECT_TRUE(ExitSplit && 74 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 75 EXPECT_FALSE(verifyFunction(*Outlined)); 76 EXPECT_FALSE(verifyFunction(*Func)); 77 } 78 79 TEST(CodeExtractor, ExitPHIOnePredFromRegion) { 80 LLVMContext Ctx; 81 SMDiagnostic Err; 82 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 83 define i32 @foo() { 84 header: 85 br i1 undef, label %extracted1, label %pred 86 87 pred: 88 br i1 undef, label %exit1, label %exit2 89 90 extracted1: 91 br i1 undef, label %extracted2, label %exit1 92 93 extracted2: 94 br label %exit2 95 96 exit1: 97 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] 98 ret i32 %0 99 100 exit2: 101 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] 102 ret i32 %1 103 } 104 )invalid", Err, Ctx)); 105 106 Function *Func = M->getFunction("foo"); 107 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 108 getBlockByName(Func, "extracted1"), 109 getBlockByName(Func, "extracted2") 110 }; 111 112 CodeExtractor CE(ExtractedBlocks); 113 EXPECT_TRUE(CE.isEligible()); 114 115 Function *Outlined = CE.extractCodeRegion(); 116 EXPECT_TRUE(Outlined); 117 BasicBlock *Exit1 = getBlockByName(Func, "exit1"); 118 BasicBlock *Exit2 = getBlockByName(Func, "exit2"); 119 // Ensure that PHIs in exits are not splitted (since that they have only one 120 // incoming value from extracted region). 121 EXPECT_TRUE(Exit1 && 122 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2); 123 EXPECT_TRUE(Exit2 && 124 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2); 125 EXPECT_FALSE(verifyFunction(*Outlined)); 126 EXPECT_FALSE(verifyFunction(*Func)); 127 } 128 129 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) { 130 LLVMContext Ctx; 131 SMDiagnostic Err; 132 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 133 declare i8 @hoge() 134 135 define i32 @foo() personality i8* null { 136 entry: 137 %call = invoke i8 @hoge() 138 to label %invoke.cont unwind label %lpad 139 140 invoke.cont: ; preds = %entry 141 unreachable 142 143 lpad: ; preds = %entry 144 %0 = landingpad { i8*, i32 } 145 catch i8* null 146 br i1 undef, label %catch, label %finally.catchall 147 148 catch: ; preds = %lpad 149 %call2 = invoke i8 @hoge() 150 to label %invoke.cont2 unwind label %lpad2 151 152 invoke.cont2: ; preds = %catch 153 %call3 = invoke i8 @hoge() 154 to label %invoke.cont3 unwind label %lpad2 155 156 invoke.cont3: ; preds = %invoke.cont2 157 unreachable 158 159 lpad2: ; preds = %invoke.cont2, %catch 160 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ] 161 %1 = landingpad { i8*, i32 } 162 catch i8* null 163 br label %finally.catchall 164 165 finally.catchall: ; preds = %lpad33, %lpad 166 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ] 167 unreachable 168 } 169 )invalid", Err, Ctx)); 170 171 if (!M) { 172 Err.print("unit", errs()); 173 exit(1); 174 } 175 176 Function *Func = M->getFunction("foo"); 177 EXPECT_FALSE(verifyFunction(*Func, &errs())); 178 179 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 180 getBlockByName(Func, "catch"), 181 getBlockByName(Func, "invoke.cont2"), 182 getBlockByName(Func, "invoke.cont3"), 183 getBlockByName(Func, "lpad2") 184 }; 185 186 CodeExtractor CE(ExtractedBlocks); 187 EXPECT_TRUE(CE.isEligible()); 188 189 Function *Outlined = CE.extractCodeRegion(); 190 EXPECT_TRUE(Outlined); 191 EXPECT_FALSE(verifyFunction(*Outlined, &errs())); 192 EXPECT_FALSE(verifyFunction(*Func, &errs())); 193 } 194 195 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { 196 LLVMContext Ctx; 197 SMDiagnostic Err; 198 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 199 declare i32 @bar() 200 201 define i32 @foo() personality i8* null { 202 entry: 203 %0 = invoke i32 @bar() to label %exit unwind label %lpad 204 205 exit: 206 ret i32 %0 207 208 lpad: 209 %1 = landingpad { i8*, i32 } 210 cleanup 211 resume { i8*, i32 } %1 212 } 213 )invalid", 214 Err, Ctx)); 215 216 Function *Func = M->getFunction("foo"); 217 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"), 218 getBlockByName(Func, "lpad") }; 219 220 CodeExtractor CE(Blocks); 221 EXPECT_TRUE(CE.isEligible()); 222 223 Function *Outlined = CE.extractCodeRegion(); 224 EXPECT_TRUE(Outlined); 225 EXPECT_FALSE(verifyFunction(*Outlined)); 226 EXPECT_FALSE(verifyFunction(*Func)); 227 } 228 229 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) { 230 LLVMContext Ctx; 231 SMDiagnostic Err; 232 std::unique_ptr<Module> M(parseAssemblyString(R"ir( 233 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128" 234 target triple = "aarch64" 235 236 %b = type { i64 } 237 declare void @g(i8*) 238 239 declare void @llvm.assume(i1) #0 240 241 define void @test() { 242 entry: 243 br label %label 244 245 label: 246 %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8 247 %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0 248 %2 = load i64, i64* %1, align 8 249 %3 = icmp ugt i64 %2, 1 250 br i1 %3, label %if.then, label %if.else 251 252 if.then: 253 unreachable 254 255 if.else: 256 call void @g(i8* undef) 257 store i64 undef, i64* null, align 536870912 258 %4 = icmp eq i64 %2, 0 259 call void @llvm.assume(i1 %4) 260 unreachable 261 } 262 263 attributes #0 = { nounwind willreturn } 264 )ir", 265 Err, Ctx)); 266 267 assert(M && "Could not parse module?"); 268 Function *Func = M->getFunction("test"); 269 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") }; 270 AssumptionCache AC(*Func); 271 CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC); 272 EXPECT_TRUE(CE.isEligible()); 273 274 Function *Outlined = CE.extractCodeRegion(); 275 EXPECT_TRUE(Outlined); 276 EXPECT_FALSE(verifyFunction(*Outlined)); 277 EXPECT_FALSE(verifyFunction(*Func)); 278 EXPECT_FALSE(CE.verifyAssumptionCache(*Func, &AC)); 279 } 280 } // end anonymous namespace 281