xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision 50afaa9d34d6447b04286335d9e85bd70637ecff)
//===- CodeExtractorTest.cpp - Unit tests for CodeExtractor ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 using namespace llvm;
23 
24 namespace {
25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26   for (auto &BB : *F)
27     if (BB.getName() == name)
28       return &BB;
29   return nullptr;
30 }
31 
32 TEST(CodeExtractor, ExitStub) {
33   LLVMContext Ctx;
34   SMDiagnostic Err;
35   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
36     define i32 @foo(i32 %x, i32 %y, i32 %z) {
37     header:
38       %0 = icmp ugt i32 %x, %y
39       br i1 %0, label %body1, label %body2
40 
41     body1:
42       %1 = add i32 %z, 2
43       br label %notExtracted
44 
45     body2:
46       %2 = mul i32 %z, 7
47       br label %notExtracted
48 
49     notExtracted:
50       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
51       %4 = add i32 %3, %x
52       ret i32 %4
53     }
54   )invalid",
55                                                 Err, Ctx));
56 
57   Function *Func = M->getFunction("foo");
58   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
59                                            getBlockByName(Func, "body1"),
60                                            getBlockByName(Func, "body2") };
61 
62   CodeExtractor CE(Candidates);
63   EXPECT_TRUE(CE.isEligible());
64 
65   Function *Outlined = CE.extractCodeRegion();
66   EXPECT_TRUE(Outlined);
67   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
68   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
69   // Ensure that PHI in exit block has only one incoming value (from code
70   // replacer block).
71   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
72   // Ensure that there is a PHI in outlined function with 2 incoming values.
73   EXPECT_TRUE(ExitSplit &&
74               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
75   EXPECT_FALSE(verifyFunction(*Outlined));
76   EXPECT_FALSE(verifyFunction(*Func));
77 }
78 
79 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
80   LLVMContext Ctx;
81   SMDiagnostic Err;
82   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
83     define i32 @foo() {
84     header:
85       br i1 undef, label %extracted1, label %pred
86 
87     pred:
88       br i1 undef, label %exit1, label %exit2
89 
90     extracted1:
91       br i1 undef, label %extracted2, label %exit1
92 
93     extracted2:
94       br label %exit2
95 
96     exit1:
97       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
98       ret i32 %0
99 
100     exit2:
101       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
102       ret i32 %1
103     }
104   )invalid", Err, Ctx));
105 
106   Function *Func = M->getFunction("foo");
107   SmallVector<BasicBlock *, 2> ExtractedBlocks{
108     getBlockByName(Func, "extracted1"),
109     getBlockByName(Func, "extracted2")
110   };
111 
112   CodeExtractor CE(ExtractedBlocks);
113   EXPECT_TRUE(CE.isEligible());
114 
115   Function *Outlined = CE.extractCodeRegion();
116   EXPECT_TRUE(Outlined);
117   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
118   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
119   // Ensure that PHIs in exits are not splitted (since that they have only one
120   // incoming value from extracted region).
121   EXPECT_TRUE(Exit1 &&
122           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
123   EXPECT_TRUE(Exit2 &&
124           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
125   EXPECT_FALSE(verifyFunction(*Outlined));
126   EXPECT_FALSE(verifyFunction(*Func));
127 }
128 
129 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
130   LLVMContext Ctx;
131   SMDiagnostic Err;
132   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
133     declare i8 @hoge()
134 
135     define i32 @foo() personality i8* null {
136       entry:
137         %call = invoke i8 @hoge()
138                 to label %invoke.cont unwind label %lpad
139 
140       invoke.cont:                                      ; preds = %entry
141         unreachable
142 
143       lpad:                                             ; preds = %entry
144         %0 = landingpad { i8*, i32 }
145                 catch i8* null
146         br i1 undef, label %catch, label %finally.catchall
147 
148       catch:                                            ; preds = %lpad
149         %call2 = invoke i8 @hoge()
150                 to label %invoke.cont2 unwind label %lpad2
151 
152       invoke.cont2:                                    ; preds = %catch
153         %call3 = invoke i8 @hoge()
154                 to label %invoke.cont3 unwind label %lpad2
155 
156       invoke.cont3:                                    ; preds = %invoke.cont2
157         unreachable
158 
159       lpad2:                                           ; preds = %invoke.cont2, %catch
160         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
161         %1 = landingpad { i8*, i32 }
162                 catch i8* null
163         br label %finally.catchall
164 
165       finally.catchall:                                 ; preds = %lpad33, %lpad
166         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
167         unreachable
168     }
169   )invalid", Err, Ctx));
170 
171 	if (!M) {
172     Err.print("unit", errs());
173     exit(1);
174   }
175 
176   Function *Func = M->getFunction("foo");
177   EXPECT_FALSE(verifyFunction(*Func, &errs()));
178 
179   SmallVector<BasicBlock *, 2> ExtractedBlocks{
180     getBlockByName(Func, "catch"),
181     getBlockByName(Func, "invoke.cont2"),
182     getBlockByName(Func, "invoke.cont3"),
183     getBlockByName(Func, "lpad2")
184   };
185 
186   CodeExtractor CE(ExtractedBlocks);
187   EXPECT_TRUE(CE.isEligible());
188 
189   Function *Outlined = CE.extractCodeRegion();
190   EXPECT_TRUE(Outlined);
191   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
192   EXPECT_FALSE(verifyFunction(*Func, &errs()));
193 }
194 
195 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
196   LLVMContext Ctx;
197   SMDiagnostic Err;
198   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
199     declare i32 @bar()
200 
201     define i32 @foo() personality i8* null {
202     entry:
203       %0 = invoke i32 @bar() to label %exit unwind label %lpad
204 
205     exit:
206       ret i32 %0
207 
208     lpad:
209       %1 = landingpad { i8*, i32 }
210               cleanup
211       resume { i8*, i32 } %1
212     }
213   )invalid",
214                                                 Err, Ctx));
215 
216   Function *Func = M->getFunction("foo");
217   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
218                                        getBlockByName(Func, "lpad") };
219 
220   CodeExtractor CE(Blocks);
221   EXPECT_TRUE(CE.isEligible());
222 
223   Function *Outlined = CE.extractCodeRegion();
224   EXPECT_TRUE(Outlined);
225   EXPECT_FALSE(verifyFunction(*Outlined));
226   EXPECT_FALSE(verifyFunction(*Func));
227 }
228 
229 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
230   LLVMContext Ctx;
231   SMDiagnostic Err;
232   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
233         target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
234         target triple = "aarch64"
235 
236         %b = type { i64 }
237         declare void @g(i8*)
238 
239         declare void @llvm.assume(i1) #0
240 
241         define void @test() {
242         entry:
243           br label %label
244 
245         label:
246           %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
247           %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
248           %2 = load i64, i64* %1, align 8
249           %3 = icmp ugt i64 %2, 1
250           br i1 %3, label %if.then, label %if.else
251 
252         if.then:
253           unreachable
254 
255         if.else:
256           call void @g(i8* undef)
257           store i64 undef, i64* null, align 536870912
258           %4 = icmp eq i64 %2, 0
259           call void @llvm.assume(i1 %4)
260           unreachable
261         }
262 
263         attributes #0 = { nounwind willreturn }
264   )ir",
265                                                 Err, Ctx));
266 
267   assert(M && "Could not parse module?");
268   Function *Func = M->getFunction("test");
269   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
270   AssumptionCache AC(*Func);
271   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
272   EXPECT_TRUE(CE.isEligible());
273 
274   Function *Outlined = CE.extractCodeRegion();
275   EXPECT_TRUE(Outlined);
276   EXPECT_FALSE(verifyFunction(*Outlined));
277   EXPECT_FALSE(verifyFunction(*Func));
278   EXPECT_FALSE(CE.verifyAssumptionCache(*Func, &AC));
279 }
280 } // end anonymous namespace
281