1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Utils/CodeExtractor.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/IR/BasicBlock.h" 12 #include "llvm/IR/Dominators.h" 13 #include "llvm/IR/Instructions.h" 14 #include "llvm/IR/LLVMContext.h" 15 #include "llvm/IR/Module.h" 16 #include "llvm/IR/Verifier.h" 17 #include "llvm/IRReader/IRReader.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "gtest/gtest.h" 20 21 using namespace llvm; 22 23 namespace { 24 BasicBlock *getBlockByName(Function *F, StringRef name) { 25 for (auto &BB : *F) 26 if (BB.getName() == name) 27 return &BB; 28 return nullptr; 29 } 30 31 TEST(CodeExtractor, ExitStub) { 32 LLVMContext Ctx; 33 SMDiagnostic Err; 34 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 35 define i32 @foo(i32 %x, i32 %y, i32 %z) { 36 header: 37 %0 = icmp ugt i32 %x, %y 38 br i1 %0, label %body1, label %body2 39 40 body1: 41 %1 = add i32 %z, 2 42 br label %notExtracted 43 44 body2: 45 %2 = mul i32 %z, 7 46 br label %notExtracted 47 48 notExtracted: 49 %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ] 50 %4 = add i32 %3, %x 51 ret i32 %4 52 } 53 )invalid", 54 Err, Ctx)); 55 56 Function *Func = M->getFunction("foo"); 57 SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"), 58 getBlockByName(Func, "body1"), 59 getBlockByName(Func, "body2") }; 60 61 CodeExtractor CE(Candidates); 62 EXPECT_TRUE(CE.isEligible()); 63 64 Function *Outlined = CE.extractCodeRegion(); 65 EXPECT_TRUE(Outlined); 66 BasicBlock *Exit = getBlockByName(Func, "notExtracted"); 67 BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split"); 68 // Ensure that PHI in exit block has only one incoming value (from code 69 // replacer block). 70 EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1); 71 // Ensure that there is a PHI in outlined function with 2 incoming values. 72 EXPECT_TRUE(ExitSplit && 73 cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2); 74 EXPECT_FALSE(verifyFunction(*Outlined)); 75 EXPECT_FALSE(verifyFunction(*Func)); 76 } 77 78 TEST(CodeExtractor, ExitPHIOnePredFromRegion) { 79 LLVMContext Ctx; 80 SMDiagnostic Err; 81 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 82 define i32 @foo() { 83 header: 84 br i1 undef, label %extracted1, label %pred 85 86 pred: 87 br i1 undef, label %exit1, label %exit2 88 89 extracted1: 90 br i1 undef, label %extracted2, label %exit1 91 92 extracted2: 93 br label %exit2 94 95 exit1: 96 %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ] 97 ret i32 %0 98 99 exit2: 100 %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ] 101 ret i32 %1 102 } 103 )invalid", Err, Ctx)); 104 105 Function *Func = M->getFunction("foo"); 106 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 107 getBlockByName(Func, "extracted1"), 108 getBlockByName(Func, "extracted2") 109 }; 110 111 CodeExtractor CE(ExtractedBlocks); 112 EXPECT_TRUE(CE.isEligible()); 113 114 Function *Outlined = CE.extractCodeRegion(); 115 EXPECT_TRUE(Outlined); 116 BasicBlock *Exit1 = getBlockByName(Func, "exit1"); 117 BasicBlock *Exit2 = getBlockByName(Func, "exit2"); 118 // Ensure that PHIs in exits are not splitted (since that they have only one 119 // incoming value from extracted region). 120 EXPECT_TRUE(Exit1 && 121 cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2); 122 EXPECT_TRUE(Exit2 && 123 cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2); 124 EXPECT_FALSE(verifyFunction(*Outlined)); 125 EXPECT_FALSE(verifyFunction(*Func)); 126 } 127 128 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) { 129 LLVMContext Ctx; 130 SMDiagnostic Err; 131 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 132 declare i8 @hoge() 133 134 define i32 @foo() personality i8* null { 135 entry: 136 %call = invoke i8 @hoge() 137 to label %invoke.cont unwind label %lpad 138 139 invoke.cont: ; preds = %entry 140 unreachable 141 142 lpad: ; preds = %entry 143 %0 = landingpad { i8*, i32 } 144 catch i8* null 145 br i1 undef, label %catch, label %finally.catchall 146 147 catch: ; preds = %lpad 148 %call2 = invoke i8 @hoge() 149 to label %invoke.cont2 unwind label %lpad2 150 151 invoke.cont2: ; preds = %catch 152 %call3 = invoke i8 @hoge() 153 to label %invoke.cont3 unwind label %lpad2 154 155 invoke.cont3: ; preds = %invoke.cont2 156 unreachable 157 158 lpad2: ; preds = %invoke.cont2, %catch 159 %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ] 160 %1 = landingpad { i8*, i32 } 161 catch i8* null 162 br label %finally.catchall 163 164 finally.catchall: ; preds = %lpad33, %lpad 165 %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ] 166 unreachable 167 } 168 )invalid", Err, Ctx)); 169 170 if (!M) { 171 Err.print("unit", errs()); 172 exit(1); 173 } 174 175 Function *Func = M->getFunction("foo"); 176 EXPECT_FALSE(verifyFunction(*Func, &errs())); 177 178 SmallVector<BasicBlock *, 2> ExtractedBlocks{ 179 getBlockByName(Func, "catch"), 180 getBlockByName(Func, "invoke.cont2"), 181 getBlockByName(Func, "invoke.cont3"), 182 getBlockByName(Func, "lpad2") 183 }; 184 185 CodeExtractor CE(ExtractedBlocks); 186 EXPECT_TRUE(CE.isEligible()); 187 188 Function *Outlined = CE.extractCodeRegion(); 189 EXPECT_TRUE(Outlined); 190 EXPECT_FALSE(verifyFunction(*Outlined, &errs())); 191 EXPECT_FALSE(verifyFunction(*Func, &errs())); 192 } 193 194 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) { 195 LLVMContext Ctx; 196 SMDiagnostic Err; 197 std::unique_ptr<Module> M(parseAssemblyString(R"invalid( 198 declare i32 @bar() 199 200 define i32 @foo() personality i8* null { 201 entry: 202 %0 = invoke i32 @bar() to label %exit unwind label %lpad 203 204 exit: 205 ret i32 %0 206 207 lpad: 208 %1 = landingpad { i8*, i32 } 209 cleanup 210 resume { i8*, i32 } %1 211 } 212 )invalid", 213 Err, Ctx)); 214 215 Function *Func = M->getFunction("foo"); 216 SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"), 217 getBlockByName(Func, "lpad") }; 218 219 CodeExtractor CE(Blocks); 220 EXPECT_TRUE(CE.isEligible()); 221 222 Function *Outlined = CE.extractCodeRegion(); 223 EXPECT_TRUE(Outlined); 224 EXPECT_FALSE(verifyFunction(*Outlined)); 225 EXPECT_FALSE(verifyFunction(*Func)); 226 } 227 228 } // end anonymous namespace 229