xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision 95b981ca2ae3915464a63d42eb53b0dde4a88227)
1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 using namespace llvm;
23 
24 namespace {
25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26   for (auto &BB : *F)
27     if (BB.getName() == name)
28       return &BB;
29   return nullptr;
30 }
31 
32 TEST(CodeExtractor, ExitStub) {
33   LLVMContext Ctx;
34   SMDiagnostic Err;
35   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
36     define i32 @foo(i32 %x, i32 %y, i32 %z) {
37     header:
38       %0 = icmp ugt i32 %x, %y
39       br i1 %0, label %body1, label %body2
40 
41     body1:
42       %1 = add i32 %z, 2
43       br label %notExtracted
44 
45     body2:
46       %2 = mul i32 %z, 7
47       br label %notExtracted
48 
49     notExtracted:
50       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
51       %4 = add i32 %3, %x
52       ret i32 %4
53     }
54   )invalid",
55                                                 Err, Ctx));
56 
57   Function *Func = M->getFunction("foo");
58   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
59                                            getBlockByName(Func, "body1"),
60                                            getBlockByName(Func, "body2") };
61 
62   CodeExtractor CE(Candidates);
63   EXPECT_TRUE(CE.isEligible());
64 
65   CodeExtractorAnalysisCache CEAC(*Func);
66   Function *Outlined = CE.extractCodeRegion(CEAC);
67   EXPECT_TRUE(Outlined);
68   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
69   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
70   // Ensure that PHI in exit block has only one incoming value (from code
71   // replacer block).
72   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
73   // Ensure that there is a PHI in outlined function with 2 incoming values.
74   EXPECT_TRUE(ExitSplit &&
75               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
76   EXPECT_FALSE(verifyFunction(*Outlined));
77   EXPECT_FALSE(verifyFunction(*Func));
78 }
79 
80 TEST(CodeExtractor, InputOutputMonitoring) {
81   LLVMContext Ctx;
82   SMDiagnostic Err;
83   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
84     define i32 @foo(i32 %x, i32 %y, i32 %z) {
85     header:
86       %0 = icmp ugt i32 %x, %y
87       br i1 %0, label %body1, label %body2
88 
89     body1:
90       %1 = add i32 %z, 2
91       br label %notExtracted
92 
93     body2:
94       %2 = mul i32 %z, 7
95       br label %notExtracted
96 
97     notExtracted:
98       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
99       %4 = add i32 %3, %x
100       ret i32 %4
101     }
102   )invalid",
103                                                 Err, Ctx));
104 
105   Function *Func = M->getFunction("foo");
106   SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
107                                           getBlockByName(Func, "body1"),
108                                           getBlockByName(Func, "body2")};
109 
110   CodeExtractor CE(Candidates);
111   EXPECT_TRUE(CE.isEligible());
112 
113   CodeExtractorAnalysisCache CEAC(*Func);
114   SetVector<Value *> Inputs, Outputs;
115   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
116   EXPECT_TRUE(Outlined);
117 
118   EXPECT_EQ(Inputs.size(), 3u);
119   EXPECT_EQ(Inputs[0], Func->getArg(2));
120   EXPECT_EQ(Inputs[1], Func->getArg(0));
121   EXPECT_EQ(Inputs[2], Func->getArg(1));
122   EXPECT_EQ(Outputs.size(), 1u);
123   StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
124   Value *OutputVal = SI->getValueOperand();
125   EXPECT_EQ(Outputs[0], OutputVal);
126   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
127   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
128   // Ensure that PHI in exit block has only one incoming value (from code
129   // replacer block).
130   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
131   // Ensure that there is a PHI in outlined function with 2 incoming values.
132   EXPECT_TRUE(ExitSplit &&
133               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
134   EXPECT_FALSE(verifyFunction(*Outlined));
135   EXPECT_FALSE(verifyFunction(*Func));
136 }
137 
138 TEST(CodeExtractor, ExitBlockOrderingPhis) {
139   LLVMContext Ctx;
140   SMDiagnostic Err;
141   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
142     define void @foo(i32 %a, i32 %b) {
143     entry:
144       %0 = alloca i32, align 4
145       br label %test0
146     test0:
147       %c = load i32, i32* %0, align 4
148       br label %test1
149     test1:
150       %e = load i32, i32* %0, align 4
151       br i1 true, label %first, label %test
152     test:
153       %d = load i32, i32* %0, align 4
154       br i1 true, label %first, label %next
155     first:
156       %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
157       ret void
158     next:
159       %2 = add i32 %d, 1
160       %3 = add i32 %e, 1
161       ret void
162     }
163   )invalid",
164                                                 Err, Ctx));
165   Function *Func = M->getFunction("foo");
166   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
167                                            getBlockByName(Func, "test1"),
168                                            getBlockByName(Func, "test") };
169 
170   CodeExtractor CE(Candidates);
171   EXPECT_TRUE(CE.isEligible());
172 
173   CodeExtractorAnalysisCache CEAC(*Func);
174   Function *Outlined = CE.extractCodeRegion(CEAC);
175   EXPECT_TRUE(Outlined);
176 
177   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
178   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
179 
180   Instruction *FirstTerm = FirstExitStub->getTerminator();
181   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
182   EXPECT_TRUE(FirstReturn);
183   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
184   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
185 
186   Instruction *NextTerm = NextExitStub->getTerminator();
187   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
188   EXPECT_TRUE(NextReturn);
189   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
190   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
191 
192   EXPECT_FALSE(verifyFunction(*Outlined));
193   EXPECT_FALSE(verifyFunction(*Func));
194 }
195 
196 TEST(CodeExtractor, ExitBlockOrdering) {
197   LLVMContext Ctx;
198   SMDiagnostic Err;
199   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
200     define void @foo(i32 %a, i32 %b) {
201     entry:
202       %0 = alloca i32, align 4
203       br label %test0
204     test0:
205       %c = load i32, i32* %0, align 4
206       br label %test1
207     test1:
208       %e = load i32, i32* %0, align 4
209       br i1 true, label %first, label %test
210     test:
211       %d = load i32, i32* %0, align 4
212       br i1 true, label %first, label %next
213     first:
214       ret void
215     next:
216       %1 = add i32 %d, 1
217       %2 = add i32 %e, 1
218       ret void
219     }
220   )invalid",
221                                                 Err, Ctx));
222   Function *Func = M->getFunction("foo");
223   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
224                                            getBlockByName(Func, "test1"),
225                                            getBlockByName(Func, "test") };
226 
227   CodeExtractor CE(Candidates);
228   EXPECT_TRUE(CE.isEligible());
229 
230   CodeExtractorAnalysisCache CEAC(*Func);
231   Function *Outlined = CE.extractCodeRegion(CEAC);
232   EXPECT_TRUE(Outlined);
233 
234   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
235   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
236 
237   Instruction *FirstTerm = FirstExitStub->getTerminator();
238   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
239   EXPECT_TRUE(FirstReturn);
240   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
241   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
242 
243   Instruction *NextTerm = NextExitStub->getTerminator();
244   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
245   EXPECT_TRUE(NextReturn);
246   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
247   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
248 
249   EXPECT_FALSE(verifyFunction(*Outlined));
250   EXPECT_FALSE(verifyFunction(*Func));
251 }
252 
253 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
254   LLVMContext Ctx;
255   SMDiagnostic Err;
256   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
257     define i32 @foo() {
258     header:
259       br i1 undef, label %extracted1, label %pred
260 
261     pred:
262       br i1 undef, label %exit1, label %exit2
263 
264     extracted1:
265       br i1 undef, label %extracted2, label %exit1
266 
267     extracted2:
268       br label %exit2
269 
270     exit1:
271       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
272       ret i32 %0
273 
274     exit2:
275       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
276       ret i32 %1
277     }
278   )invalid", Err, Ctx));
279 
280   Function *Func = M->getFunction("foo");
281   SmallVector<BasicBlock *, 2> ExtractedBlocks{
282     getBlockByName(Func, "extracted1"),
283     getBlockByName(Func, "extracted2")
284   };
285 
286   CodeExtractor CE(ExtractedBlocks);
287   EXPECT_TRUE(CE.isEligible());
288 
289   CodeExtractorAnalysisCache CEAC(*Func);
290   Function *Outlined = CE.extractCodeRegion(CEAC);
291   EXPECT_TRUE(Outlined);
292   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
293   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
294   // Ensure that PHIs in exits are not splitted (since that they have only one
295   // incoming value from extracted region).
296   EXPECT_TRUE(Exit1 &&
297           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
298   EXPECT_TRUE(Exit2 &&
299           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
300   EXPECT_FALSE(verifyFunction(*Outlined));
301   EXPECT_FALSE(verifyFunction(*Func));
302 }
303 
304 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
305   LLVMContext Ctx;
306   SMDiagnostic Err;
307   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
308     declare i8 @hoge()
309 
310     define i32 @foo() personality i8* null {
311       entry:
312         %call = invoke i8 @hoge()
313                 to label %invoke.cont unwind label %lpad
314 
315       invoke.cont:                                      ; preds = %entry
316         unreachable
317 
318       lpad:                                             ; preds = %entry
319         %0 = landingpad { i8*, i32 }
320                 catch i8* null
321         br i1 undef, label %catch, label %finally.catchall
322 
323       catch:                                            ; preds = %lpad
324         %call2 = invoke i8 @hoge()
325                 to label %invoke.cont2 unwind label %lpad2
326 
327       invoke.cont2:                                    ; preds = %catch
328         %call3 = invoke i8 @hoge()
329                 to label %invoke.cont3 unwind label %lpad2
330 
331       invoke.cont3:                                    ; preds = %invoke.cont2
332         unreachable
333 
334       lpad2:                                           ; preds = %invoke.cont2, %catch
335         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
336         %1 = landingpad { i8*, i32 }
337                 catch i8* null
338         br label %finally.catchall
339 
340       finally.catchall:                                 ; preds = %lpad33, %lpad
341         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
342         unreachable
343     }
344   )invalid", Err, Ctx));
345 
346 	if (!M) {
347     Err.print("unit", errs());
348     exit(1);
349   }
350 
351   Function *Func = M->getFunction("foo");
352   EXPECT_FALSE(verifyFunction(*Func, &errs()));
353 
354   SmallVector<BasicBlock *, 2> ExtractedBlocks{
355     getBlockByName(Func, "catch"),
356     getBlockByName(Func, "invoke.cont2"),
357     getBlockByName(Func, "invoke.cont3"),
358     getBlockByName(Func, "lpad2")
359   };
360 
361   CodeExtractor CE(ExtractedBlocks);
362   EXPECT_TRUE(CE.isEligible());
363 
364   CodeExtractorAnalysisCache CEAC(*Func);
365   Function *Outlined = CE.extractCodeRegion(CEAC);
366   EXPECT_TRUE(Outlined);
367   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
368   EXPECT_FALSE(verifyFunction(*Func, &errs()));
369 }
370 
371 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
372   LLVMContext Ctx;
373   SMDiagnostic Err;
374   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
375     declare i32 @bar()
376 
377     define i32 @foo() personality i8* null {
378     entry:
379       %0 = invoke i32 @bar() to label %exit unwind label %lpad
380 
381     exit:
382       ret i32 %0
383 
384     lpad:
385       %1 = landingpad { i8*, i32 }
386               cleanup
387       resume { i8*, i32 } %1
388     }
389   )invalid",
390                                                 Err, Ctx));
391 
392   Function *Func = M->getFunction("foo");
393   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
394                                        getBlockByName(Func, "lpad") };
395 
396   CodeExtractor CE(Blocks);
397   EXPECT_TRUE(CE.isEligible());
398 
399   CodeExtractorAnalysisCache CEAC(*Func);
400   Function *Outlined = CE.extractCodeRegion(CEAC);
401   EXPECT_TRUE(Outlined);
402   EXPECT_FALSE(verifyFunction(*Outlined));
403   EXPECT_FALSE(verifyFunction(*Func));
404 }
405 
406 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
407   LLVMContext Ctx;
408   SMDiagnostic Err;
409   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
410         target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
411         target triple = "aarch64"
412 
413         %b = type { i64 }
414         declare void @g(i8*)
415 
416         declare void @llvm.assume(i1) #0
417 
418         define void @test() {
419         entry:
420           br label %label
421 
422         label:
423           %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
424           %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
425           %2 = load i64, i64* %1, align 8
426           %3 = icmp ugt i64 %2, 1
427           br i1 %3, label %if.then, label %if.else
428 
429         if.then:
430           unreachable
431 
432         if.else:
433           call void @g(i8* undef)
434           store i64 undef, i64* null, align 536870912
435           %4 = icmp eq i64 %2, 0
436           call void @llvm.assume(i1 %4)
437           unreachable
438         }
439 
440         attributes #0 = { nounwind willreturn }
441   )ir",
442                                                 Err, Ctx));
443 
444   assert(M && "Could not parse module?");
445   Function *Func = M->getFunction("test");
446   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
447   AssumptionCache AC(*Func);
448   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
449   EXPECT_TRUE(CE.isEligible());
450 
451   CodeExtractorAnalysisCache CEAC(*Func);
452   Function *Outlined = CE.extractCodeRegion(CEAC);
453   EXPECT_TRUE(Outlined);
454   EXPECT_FALSE(verifyFunction(*Outlined));
455   EXPECT_FALSE(verifyFunction(*Func));
456   EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
457 }
458 
459 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
460   LLVMContext Ctx;
461   SMDiagnostic Err;
462   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
463     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
464     target triple = "x86_64-unknown-linux-gnu"
465 
466     declare void @use(i32*)
467     declare void @llvm.lifetime.start.p0i8(i64, i8*)
468     declare void @llvm.lifetime.end.p0i8(i64, i8*)
469 
470     define void @foo() {
471     entry:
472       %0 = alloca i32
473       br label %extract
474 
475     extract:
476       %1 = bitcast i32* %0 to i8*
477       call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
478       call void @use(i32* %0)
479       br label %exit
480 
481     exit:
482       call void @use(i32* %0)
483       call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
484       ret void
485     }
486   )ir",
487                                                 Err, Ctx));
488 
489   Function *Func = M->getFunction("foo");
490   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
491 
492   CodeExtractor CE(Blocks);
493   EXPECT_TRUE(CE.isEligible());
494 
495   CodeExtractorAnalysisCache CEAC(*Func);
496   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
497   BasicBlock *CommonExit = nullptr;
498   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
499   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
500   EXPECT_EQ(Outputs.size(), 0U);
501 
502   Function *Outlined = CE.extractCodeRegion(CEAC);
503   EXPECT_TRUE(Outlined);
504   EXPECT_FALSE(verifyFunction(*Outlined));
505   EXPECT_FALSE(verifyFunction(*Func));
506 }
507 
508 TEST(CodeExtractor, PartialAggregateArgs) {
509   LLVMContext Ctx;
510   SMDiagnostic Err;
511   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
512     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
513     target triple = "x86_64-unknown-linux-gnu"
514 
515     declare void @use(i32)
516 
517     define void @foo(i32 %a, i32 %b, i32 %c) {
518     entry:
519       br label %extract
520 
521     extract:
522       call void @use(i32 %a)
523       call void @use(i32 %b)
524       call void @use(i32 %c)
525       br label %exit
526 
527     exit:
528       ret void
529     }
530   )ir",
531                                                 Err, Ctx));
532 
533   Function *Func = M->getFunction("foo");
534   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
535 
536   // Create the CodeExtractor with arguments aggregation enabled.
537   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
538                    /* AggregateArgs */ true);
539   EXPECT_TRUE(CE.isEligible());
540 
541   CodeExtractorAnalysisCache CEAC(*Func);
542   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
543   BasicBlock *CommonExit = nullptr;
544   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
545   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
546   // Exclude the first input from the argument aggregate.
547   CE.excludeArgFromAggregate(Inputs[0]);
548 
549   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
550   EXPECT_TRUE(Outlined);
551   // Expect 2 arguments in the outlined function: the excluded input and the
552   // struct aggregate for the remaining inputs.
553   EXPECT_EQ(Outlined->arg_size(), 2U);
554   EXPECT_FALSE(verifyFunction(*Outlined));
555   EXPECT_FALSE(verifyFunction(*Func));
556 }
557 } // end anonymous namespace
558