xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision 0e5dd512aae057aeceb34089c93a380f8edd37da)
//===- CodeExtractorTest.cpp - Unit tests for CodeExtractor ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/BasicBlock.h"
12 #include "llvm/IR/Dominators.h"
13 #include "llvm/IR/Instructions.h"
14 #include "llvm/IR/LLVMContext.h"
15 #include "llvm/IR/Module.h"
16 #include "llvm/IR/Verifier.h"
17 #include "llvm/IRReader/IRReader.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "gtest/gtest.h"
20 
21 using namespace llvm;
22 
23 namespace {
24 BasicBlock *getBlockByName(Function *F, StringRef name) {
25   for (auto &BB : *F)
26     if (BB.getName() == name)
27       return &BB;
28   return nullptr;
29 }
30 
31 TEST(CodeExtractor, ExitStub) {
32   LLVMContext Ctx;
33   SMDiagnostic Err;
34   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
35     define i32 @foo(i32 %x, i32 %y, i32 %z) {
36     header:
37       %0 = icmp ugt i32 %x, %y
38       br i1 %0, label %body1, label %body2
39 
40     body1:
41       %1 = add i32 %z, 2
42       br label %notExtracted
43 
44     body2:
45       %2 = mul i32 %z, 7
46       br label %notExtracted
47 
48     notExtracted:
49       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
50       %4 = add i32 %3, %x
51       ret i32 %4
52     }
53   )invalid",
54                                                 Err, Ctx));
55 
56   Function *Func = M->getFunction("foo");
57   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
58                                            getBlockByName(Func, "body1"),
59                                            getBlockByName(Func, "body2") };
60 
61   CodeExtractor CE(Candidates);
62   EXPECT_TRUE(CE.isEligible());
63 
64   Function *Outlined = CE.extractCodeRegion();
65   EXPECT_TRUE(Outlined);
66   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
67   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
68   // Ensure that PHI in exit block has only one incoming value (from code
69   // replacer block).
70   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
71   // Ensure that there is a PHI in outlined function with 2 incoming values.
72   EXPECT_TRUE(ExitSplit &&
73               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
74   EXPECT_FALSE(verifyFunction(*Outlined));
75   EXPECT_FALSE(verifyFunction(*Func));
76 }
77 
78 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
79   LLVMContext Ctx;
80   SMDiagnostic Err;
81   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
82     define i32 @foo() {
83     header:
84       br i1 undef, label %extracted1, label %pred
85 
86     pred:
87       br i1 undef, label %exit1, label %exit2
88 
89     extracted1:
90       br i1 undef, label %extracted2, label %exit1
91 
92     extracted2:
93       br label %exit2
94 
95     exit1:
96       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
97       ret i32 %0
98 
99     exit2:
100       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
101       ret i32 %1
102     }
103   )invalid", Err, Ctx));
104 
105   Function *Func = M->getFunction("foo");
106   SmallVector<BasicBlock *, 2> ExtractedBlocks{
107     getBlockByName(Func, "extracted1"),
108     getBlockByName(Func, "extracted2")
109   };
110 
111   CodeExtractor CE(ExtractedBlocks);
112   EXPECT_TRUE(CE.isEligible());
113 
114   Function *Outlined = CE.extractCodeRegion();
115   EXPECT_TRUE(Outlined);
116   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
117   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
118   // Ensure that PHIs in exits are not splitted (since that they have only one
119   // incoming value from extracted region).
120   EXPECT_TRUE(Exit1 &&
121           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
122   EXPECT_TRUE(Exit2 &&
123           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
124   EXPECT_FALSE(verifyFunction(*Outlined));
125   EXPECT_FALSE(verifyFunction(*Func));
126 }
127 
128 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
129   LLVMContext Ctx;
130   SMDiagnostic Err;
131   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
132     declare i8 @hoge()
133 
134     define i32 @foo() personality i8* null {
135       entry:
136         %call = invoke i8 @hoge()
137                 to label %invoke.cont unwind label %lpad
138 
139       invoke.cont:                                      ; preds = %entry
140         unreachable
141 
142       lpad:                                             ; preds = %entry
143         %0 = landingpad { i8*, i32 }
144                 catch i8* null
145         br i1 undef, label %catch, label %finally.catchall
146 
147       catch:                                            ; preds = %lpad
148         %call2 = invoke i8 @hoge()
149                 to label %invoke.cont2 unwind label %lpad2
150 
151       invoke.cont2:                                    ; preds = %catch
152         %call3 = invoke i8 @hoge()
153                 to label %invoke.cont3 unwind label %lpad2
154 
155       invoke.cont3:                                    ; preds = %invoke.cont2
156         unreachable
157 
158       lpad2:                                           ; preds = %invoke.cont2, %catch
159         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
160         %1 = landingpad { i8*, i32 }
161                 catch i8* null
162         br label %finally.catchall
163 
164       finally.catchall:                                 ; preds = %lpad33, %lpad
165         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
166         unreachable
167     }
168   )invalid", Err, Ctx));
169 
170 	if (!M) {
171     Err.print("unit", errs());
172     exit(1);
173   }
174 
175   Function *Func = M->getFunction("foo");
176   EXPECT_FALSE(verifyFunction(*Func, &errs()));
177 
178   SmallVector<BasicBlock *, 2> ExtractedBlocks{
179     getBlockByName(Func, "catch"),
180     getBlockByName(Func, "invoke.cont2"),
181     getBlockByName(Func, "invoke.cont3"),
182     getBlockByName(Func, "lpad2")
183   };
184 
185   CodeExtractor CE(ExtractedBlocks);
186   EXPECT_TRUE(CE.isEligible());
187 
188   Function *Outlined = CE.extractCodeRegion();
189   EXPECT_TRUE(Outlined);
190   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
191   EXPECT_FALSE(verifyFunction(*Func, &errs()));
192 }
193 
194 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
195   LLVMContext Ctx;
196   SMDiagnostic Err;
197   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
198     declare i32 @bar()
199 
200     define i32 @foo() personality i8* null {
201     entry:
202       %0 = invoke i32 @bar() to label %exit unwind label %lpad
203 
204     exit:
205       ret i32 %0
206 
207     lpad:
208       %1 = landingpad { i8*, i32 }
209               cleanup
210       resume { i8*, i32 } %1
211     }
212   )invalid",
213                                                 Err, Ctx));
214 
215   Function *Func = M->getFunction("foo");
216   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
217                                        getBlockByName(Func, "lpad") };
218 
219   CodeExtractor CE(Blocks);
220   EXPECT_TRUE(CE.isEligible());
221 
222   Function *Outlined = CE.extractCodeRegion();
223   EXPECT_TRUE(Outlined);
224   EXPECT_FALSE(verifyFunction(*Outlined));
225   EXPECT_FALSE(verifyFunction(*Func));
226 }
227 
228 } // end anonymous namespace
229