xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision f6795e6b4f619cbecc59a92f7e5fad7ca90ece54)
1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/IR/LLVMContext.h"
18 #include "llvm/IR/Module.h"
19 #include "llvm/IR/Verifier.h"
20 #include "llvm/IRReader/IRReader.h"
21 #include "llvm/Support/SourceMgr.h"
22 #include "gtest/gtest.h"
23 
24 using namespace llvm;
25 
26 namespace {
27 BasicBlock *getBlockByName(Function *F, StringRef name) {
28   for (auto &BB : *F)
29     if (BB.getName() == name)
30       return &BB;
31   return nullptr;
32 }
33 
34 Instruction *getInstByName(Function *F, StringRef Name) {
35   for (Instruction &I : instructions(F))
36     if (I.getName() == Name)
37       return &I;
38   return nullptr;
39 }
40 
41 TEST(CodeExtractor, ExitStub) {
42   LLVMContext Ctx;
43   SMDiagnostic Err;
44   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
45     define i32 @foo(i32 %x, i32 %y, i32 %z) {
46     header:
47       %0 = icmp ugt i32 %x, %y
48       br i1 %0, label %body1, label %body2
49 
50     body1:
51       %1 = add i32 %z, 2
52       br label %notExtracted
53 
54     body2:
55       %2 = mul i32 %z, 7
56       br label %notExtracted
57 
58     notExtracted:
59       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
60       %4 = add i32 %3, %x
61       ret i32 %4
62     }
63   )invalid",
64                                                 Err, Ctx));
65 
66   Function *Func = M->getFunction("foo");
67   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
68                                            getBlockByName(Func, "body1"),
69                                            getBlockByName(Func, "body2") };
70 
71   CodeExtractor CE(Candidates);
72   EXPECT_TRUE(CE.isEligible());
73 
74   CodeExtractorAnalysisCache CEAC(*Func);
75   Function *Outlined = CE.extractCodeRegion(CEAC);
76   EXPECT_TRUE(Outlined);
77   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
78   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
79   // Ensure that PHI in exit block has only one incoming value (from code
80   // replacer block).
81   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
82   // Ensure that there is a PHI in outlined function with 2 incoming values.
83   EXPECT_TRUE(ExitSplit &&
84               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
85   EXPECT_FALSE(verifyFunction(*Outlined));
86   EXPECT_FALSE(verifyFunction(*Func));
87 }
88 
89 TEST(CodeExtractor, InputOutputMonitoring) {
90   LLVMContext Ctx;
91   SMDiagnostic Err;
92   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
93     define i32 @foo(i32 %x, i32 %y, i32 %z) {
94     header:
95       %0 = icmp ugt i32 %x, %y
96       br i1 %0, label %body1, label %body2
97 
98     body1:
99       %1 = add i32 %z, 2
100       br label %notExtracted
101 
102     body2:
103       %2 = mul i32 %z, 7
104       br label %notExtracted
105 
106     notExtracted:
107       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
108       %4 = add i32 %3, %x
109       ret i32 %4
110     }
111   )invalid",
112                                                 Err, Ctx));
113 
114   Function *Func = M->getFunction("foo");
115   SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
116                                           getBlockByName(Func, "body1"),
117                                           getBlockByName(Func, "body2")};
118 
119   CodeExtractor CE(Candidates);
120   EXPECT_TRUE(CE.isEligible());
121 
122   CodeExtractorAnalysisCache CEAC(*Func);
123   SetVector<Value *> Inputs, Outputs;
124   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
125   EXPECT_TRUE(Outlined);
126 
127   EXPECT_EQ(Inputs.size(), 3u);
128   EXPECT_EQ(Inputs[0], Func->getArg(2));
129   EXPECT_EQ(Inputs[1], Func->getArg(0));
130   EXPECT_EQ(Inputs[2], Func->getArg(1));
131   EXPECT_EQ(Outputs.size(), 1u);
132   StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
133   Value *OutputVal = SI->getValueOperand();
134   EXPECT_EQ(Outputs[0], OutputVal);
135   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
136   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
137   // Ensure that PHI in exit block has only one incoming value (from code
138   // replacer block).
139   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
140   // Ensure that there is a PHI in outlined function with 2 incoming values.
141   EXPECT_TRUE(ExitSplit &&
142               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
143   EXPECT_FALSE(verifyFunction(*Outlined));
144   EXPECT_FALSE(verifyFunction(*Func));
145 }
146 
147 TEST(CodeExtractor, ExitBlockOrderingPhis) {
148   LLVMContext Ctx;
149   SMDiagnostic Err;
150   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
151     define void @foo(i32 %a, i32 %b) {
152     entry:
153       %0 = alloca i32, align 4
154       br label %test0
155     test0:
156       %c = load i32, i32* %0, align 4
157       br label %test1
158     test1:
159       %e = load i32, i32* %0, align 4
160       br i1 true, label %first, label %test
161     test:
162       %d = load i32, i32* %0, align 4
163       br i1 true, label %first, label %next
164     first:
165       %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
166       ret void
167     next:
168       %2 = add i32 %d, 1
169       %3 = add i32 %e, 1
170       ret void
171     }
172   )invalid",
173                                                 Err, Ctx));
174   Function *Func = M->getFunction("foo");
175   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
176                                            getBlockByName(Func, "test1"),
177                                            getBlockByName(Func, "test") };
178 
179   CodeExtractor CE(Candidates);
180   EXPECT_TRUE(CE.isEligible());
181 
182   CodeExtractorAnalysisCache CEAC(*Func);
183   Function *Outlined = CE.extractCodeRegion(CEAC);
184   EXPECT_TRUE(Outlined);
185 
186   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
187   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
188 
189   Instruction *FirstTerm = FirstExitStub->getTerminator();
190   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
191   EXPECT_TRUE(FirstReturn);
192   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
193   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
194 
195   Instruction *NextTerm = NextExitStub->getTerminator();
196   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
197   EXPECT_TRUE(NextReturn);
198   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
199   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
200 
201   EXPECT_FALSE(verifyFunction(*Outlined));
202   EXPECT_FALSE(verifyFunction(*Func));
203 }
204 
205 TEST(CodeExtractor, ExitBlockOrdering) {
206   LLVMContext Ctx;
207   SMDiagnostic Err;
208   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
209     define void @foo(i32 %a, i32 %b) {
210     entry:
211       %0 = alloca i32, align 4
212       br label %test0
213     test0:
214       %c = load i32, i32* %0, align 4
215       br label %test1
216     test1:
217       %e = load i32, i32* %0, align 4
218       br i1 true, label %first, label %test
219     test:
220       %d = load i32, i32* %0, align 4
221       br i1 true, label %first, label %next
222     first:
223       ret void
224     next:
225       %1 = add i32 %d, 1
226       %2 = add i32 %e, 1
227       ret void
228     }
229   )invalid",
230                                                 Err, Ctx));
231   Function *Func = M->getFunction("foo");
232   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
233                                            getBlockByName(Func, "test1"),
234                                            getBlockByName(Func, "test") };
235 
236   CodeExtractor CE(Candidates);
237   EXPECT_TRUE(CE.isEligible());
238 
239   CodeExtractorAnalysisCache CEAC(*Func);
240   Function *Outlined = CE.extractCodeRegion(CEAC);
241   EXPECT_TRUE(Outlined);
242 
243   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
244   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
245 
246   Instruction *FirstTerm = FirstExitStub->getTerminator();
247   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
248   EXPECT_TRUE(FirstReturn);
249   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
250   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
251 
252   Instruction *NextTerm = NextExitStub->getTerminator();
253   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
254   EXPECT_TRUE(NextReturn);
255   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
256   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
257 
258   EXPECT_FALSE(verifyFunction(*Outlined));
259   EXPECT_FALSE(verifyFunction(*Func));
260 }
261 
262 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
263   LLVMContext Ctx;
264   SMDiagnostic Err;
265   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
266     define i32 @foo() {
267     header:
268       br i1 undef, label %extracted1, label %pred
269 
270     pred:
271       br i1 undef, label %exit1, label %exit2
272 
273     extracted1:
274       br i1 undef, label %extracted2, label %exit1
275 
276     extracted2:
277       br label %exit2
278 
279     exit1:
280       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
281       ret i32 %0
282 
283     exit2:
284       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
285       ret i32 %1
286     }
287   )invalid", Err, Ctx));
288 
289   Function *Func = M->getFunction("foo");
290   SmallVector<BasicBlock *, 2> ExtractedBlocks{
291     getBlockByName(Func, "extracted1"),
292     getBlockByName(Func, "extracted2")
293   };
294 
295   CodeExtractor CE(ExtractedBlocks);
296   EXPECT_TRUE(CE.isEligible());
297 
298   CodeExtractorAnalysisCache CEAC(*Func);
299   Function *Outlined = CE.extractCodeRegion(CEAC);
300   EXPECT_TRUE(Outlined);
301   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
302   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
303   // Ensure that PHIs in exits are not splitted (since that they have only one
304   // incoming value from extracted region).
305   EXPECT_TRUE(Exit1 &&
306           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
307   EXPECT_TRUE(Exit2 &&
308           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
309   EXPECT_FALSE(verifyFunction(*Outlined));
310   EXPECT_FALSE(verifyFunction(*Func));
311 }
312 
313 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
314   LLVMContext Ctx;
315   SMDiagnostic Err;
316   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
317     declare i8 @hoge()
318 
319     define i32 @foo() personality i8* null {
320       entry:
321         %call = invoke i8 @hoge()
322                 to label %invoke.cont unwind label %lpad
323 
324       invoke.cont:                                      ; preds = %entry
325         unreachable
326 
327       lpad:                                             ; preds = %entry
328         %0 = landingpad { i8*, i32 }
329                 catch i8* null
330         br i1 undef, label %catch, label %finally.catchall
331 
332       catch:                                            ; preds = %lpad
333         %call2 = invoke i8 @hoge()
334                 to label %invoke.cont2 unwind label %lpad2
335 
336       invoke.cont2:                                    ; preds = %catch
337         %call3 = invoke i8 @hoge()
338                 to label %invoke.cont3 unwind label %lpad2
339 
340       invoke.cont3:                                    ; preds = %invoke.cont2
341         unreachable
342 
343       lpad2:                                           ; preds = %invoke.cont2, %catch
344         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
345         %1 = landingpad { i8*, i32 }
346                 catch i8* null
347         br label %finally.catchall
348 
349       finally.catchall:                                 ; preds = %lpad33, %lpad
350         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
351         unreachable
352     }
353   )invalid", Err, Ctx));
354 
355 	if (!M) {
356     Err.print("unit", errs());
357     exit(1);
358   }
359 
360   Function *Func = M->getFunction("foo");
361   EXPECT_FALSE(verifyFunction(*Func, &errs()));
362 
363   SmallVector<BasicBlock *, 2> ExtractedBlocks{
364     getBlockByName(Func, "catch"),
365     getBlockByName(Func, "invoke.cont2"),
366     getBlockByName(Func, "invoke.cont3"),
367     getBlockByName(Func, "lpad2")
368   };
369 
370   CodeExtractor CE(ExtractedBlocks);
371   EXPECT_TRUE(CE.isEligible());
372 
373   CodeExtractorAnalysisCache CEAC(*Func);
374   Function *Outlined = CE.extractCodeRegion(CEAC);
375   EXPECT_TRUE(Outlined);
376   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
377   EXPECT_FALSE(verifyFunction(*Func, &errs()));
378 }
379 
380 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
381   LLVMContext Ctx;
382   SMDiagnostic Err;
383   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
384     declare i32 @bar()
385 
386     define i32 @foo() personality i8* null {
387     entry:
388       %0 = invoke i32 @bar() to label %exit unwind label %lpad
389 
390     exit:
391       ret i32 %0
392 
393     lpad:
394       %1 = landingpad { i8*, i32 }
395               cleanup
396       resume { i8*, i32 } %1
397     }
398   )invalid",
399                                                 Err, Ctx));
400 
401   Function *Func = M->getFunction("foo");
402   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
403                                        getBlockByName(Func, "lpad") };
404 
405   CodeExtractor CE(Blocks);
406   EXPECT_TRUE(CE.isEligible());
407 
408   CodeExtractorAnalysisCache CEAC(*Func);
409   Function *Outlined = CE.extractCodeRegion(CEAC);
410   EXPECT_TRUE(Outlined);
411   EXPECT_FALSE(verifyFunction(*Outlined));
412   EXPECT_FALSE(verifyFunction(*Func));
413 }
414 
415 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
416   LLVMContext Ctx;
417   SMDiagnostic Err;
418   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
419         target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
420         target triple = "aarch64"
421 
422         %b = type { i64 }
423         declare void @g(i8*)
424 
425         declare void @llvm.assume(i1) #0
426 
427         define void @test() {
428         entry:
429           br label %label
430 
431         label:
432           %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
433           %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
434           %2 = load i64, i64* %1, align 8
435           %3 = icmp ugt i64 %2, 1
436           br i1 %3, label %if.then, label %if.else
437 
438         if.then:
439           unreachable
440 
441         if.else:
442           call void @g(i8* undef)
443           store i64 undef, i64* null, align 536870912
444           %4 = icmp eq i64 %2, 0
445           call void @llvm.assume(i1 %4)
446           unreachable
447         }
448 
449         attributes #0 = { nounwind willreturn }
450   )ir",
451                                                 Err, Ctx));
452 
453   assert(M && "Could not parse module?");
454   Function *Func = M->getFunction("test");
455   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
456   AssumptionCache AC(*Func);
457   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
458   EXPECT_TRUE(CE.isEligible());
459 
460   CodeExtractorAnalysisCache CEAC(*Func);
461   Function *Outlined = CE.extractCodeRegion(CEAC);
462   EXPECT_TRUE(Outlined);
463   EXPECT_FALSE(verifyFunction(*Outlined));
464   EXPECT_FALSE(verifyFunction(*Func));
465   EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
466 }
467 
468 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
469   LLVMContext Ctx;
470   SMDiagnostic Err;
471   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
472     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
473     target triple = "x86_64-unknown-linux-gnu"
474 
475     declare void @use(i32*)
476     declare void @llvm.lifetime.start.p0i8(i64, i8*)
477     declare void @llvm.lifetime.end.p0i8(i64, i8*)
478 
479     define void @foo() {
480     entry:
481       %0 = alloca i32
482       br label %extract
483 
484     extract:
485       %1 = bitcast i32* %0 to i8*
486       call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
487       call void @use(i32* %0)
488       br label %exit
489 
490     exit:
491       call void @use(i32* %0)
492       call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
493       ret void
494     }
495   )ir",
496                                                 Err, Ctx));
497 
498   Function *Func = M->getFunction("foo");
499   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
500 
501   CodeExtractor CE(Blocks);
502   EXPECT_TRUE(CE.isEligible());
503 
504   CodeExtractorAnalysisCache CEAC(*Func);
505   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
506   BasicBlock *CommonExit = nullptr;
507   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
508   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
509   EXPECT_EQ(Outputs.size(), 0U);
510 
511   Function *Outlined = CE.extractCodeRegion(CEAC);
512   EXPECT_TRUE(Outlined);
513   EXPECT_FALSE(verifyFunction(*Outlined));
514   EXPECT_FALSE(verifyFunction(*Func));
515 }
516 
517 TEST(CodeExtractor, PartialAggregateArgs) {
518   LLVMContext Ctx;
519   SMDiagnostic Err;
520   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
521     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
522     target triple = "x86_64-unknown-linux-gnu"
523 
524     ; use different types such that an index mismatch will result in a type mismatch during verification.
525     declare void @use16(i16)
526     declare void @use32(i32)
527     declare void @use64(i64)
528 
529     define void @foo(i16 %a, i32 %b, i64 %c) {
530     entry:
531       br label %extract
532 
533     extract:
534       call void @use16(i16 %a)
535       call void @use32(i32 %b)
536       call void @use64(i64 %c)
537       %d = add i16 21, 21
538       %e = add i32 21, 21
539       %f = add i64 21, 21
540       br label %exit
541 
542     exit:
543       call void @use16(i16 %d)
544       call void @use32(i32 %e)
545       call void @use64(i64 %f)
546       ret void
547     }
548   )ir",
549                                                 Err, Ctx));
550 
551   Function *Func = M->getFunction("foo");
552   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
553 
554   // Create the CodeExtractor with arguments aggregation enabled.
555   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
556                    /* AggregateArgs */ true);
557   EXPECT_TRUE(CE.isEligible());
558 
559   CodeExtractorAnalysisCache CEAC(*Func);
560   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
561   BasicBlock *CommonExit = nullptr;
562   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
563   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
564   // Exclude the middle input and output from the argument aggregate.
565   CE.excludeArgFromAggregate(Inputs[1]);
566   CE.excludeArgFromAggregate(Outputs[1]);
567 
568   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
569   EXPECT_TRUE(Outlined);
570   // Expect 3 arguments in the outlined function: the excluded input, the
571   // excluded output, and the struct aggregate for the remaining inputs.
572   EXPECT_EQ(Outlined->arg_size(), 3U);
573   EXPECT_FALSE(verifyFunction(*Outlined));
574   EXPECT_FALSE(verifyFunction(*Func));
575 }
576 
577 TEST(CodeExtractor, AllocaBlock) {
578   LLVMContext Ctx;
579   SMDiagnostic Err;
580   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
581     define i32 @foo(i32 %x, i32 %y, i32 %z) {
582     entry:
583       br label %allocas
584 
585     allocas:
586       br label %body
587 
588     body:
589       %w = add i32 %x, %y
590       br label %notExtracted
591 
592     notExtracted:
593       %r = add i32 %w, %x
594       ret i32 %r
595     }
596   )invalid",
597                                                 Err, Ctx));
598 
599   Function *Func = M->getFunction("foo");
600   SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "body")};
601 
602   BasicBlock *AllocaBlock = getBlockByName(Func, "allocas");
603   CodeExtractor CE(Candidates, nullptr, true, nullptr, nullptr, nullptr, false,
604                    false, AllocaBlock);
605   CE.excludeArgFromAggregate(Func->getArg(0));
606   CE.excludeArgFromAggregate(getInstByName(Func, "w"));
607   EXPECT_TRUE(CE.isEligible());
608 
609   CodeExtractorAnalysisCache CEAC(*Func);
610   SetVector<Value *> Inputs, Outputs;
611   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
612   EXPECT_TRUE(Outlined);
613   EXPECT_FALSE(verifyFunction(*Outlined));
614   EXPECT_FALSE(verifyFunction(*Func));
615 
616   // The only added allocas may be in the dedicated alloca block. There should
617   // be one alloca for the struct, and another one for the reload value.
618   int NumAllocas = 0;
619   for (Instruction &I : instructions(Func)) {
620     if (!isa<AllocaInst>(I))
621       continue;
622     EXPECT_EQ(I.getParent(), AllocaBlock);
623     NumAllocas += 1;
624   }
625   EXPECT_EQ(NumAllocas, 2);
626 }
627 
628 /// Regression test to ensure we don't crash trying to set the name of the ptr
629 /// argument
630 TEST(CodeExtractor, PartialAggregateArgs2) {
631   LLVMContext Ctx;
632   SMDiagnostic Err;
633   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
634     declare void @usei(i32)
635     declare void @usep(ptr)
636 
637     define void @foo(i32 %a, i32 %b, ptr %p) {
638     entry:
639       br label %extract
640 
641     extract:
642       call void @usei(i32 %a)
643       call void @usei(i32 %b)
644       call void @usep(ptr %p)
645       br label %exit
646 
647     exit:
648       ret void
649     }
650   )ir",
651                                                 Err, Ctx));
652 
653   Function *Func = M->getFunction("foo");
654   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
655 
656   // Create the CodeExtractor with arguments aggregation enabled.
657   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
658                    /* AggregateArgs */ true);
659   EXPECT_TRUE(CE.isEligible());
660 
661   CodeExtractorAnalysisCache CEAC(*Func);
662   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
663   BasicBlock *CommonExit = nullptr;
664   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
665   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
666   // Exclude the last input from the argument aggregate.
667   CE.excludeArgFromAggregate(Inputs[2]);
668 
669   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
670   EXPECT_TRUE(Outlined);
671   EXPECT_FALSE(verifyFunction(*Outlined));
672   EXPECT_FALSE(verifyFunction(*Func));
673 }
674 
675 TEST(CodeExtractor, OpenMPAggregateArgs) {
676   LLVMContext Ctx;
677   SMDiagnostic Err;
678   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
679     target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
680     target triple = "amdgcn-amd-amdhsa"
681 
682     define void @foo(ptr %0) {
683       %2= alloca ptr, align 8, addrspace(5)
684       %3 = addrspacecast ptr addrspace(5) %2 to ptr
685       store ptr %0, ptr %3, align 8
686       %4 = load ptr, ptr %3, align 8
687       br label %entry
688 
689    entry:
690       br label %extract
691 
692     extract:
693       store i64 10, ptr %4, align 4
694       br label %exit
695 
696     exit:
697       ret void
698     }
699   )ir",
700                                                 Err, Ctx));
701   Function *Func = M->getFunction("foo");
702   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
703 
704   // Create the CodeExtractor with arguments aggregation enabled.
705   // Outlined function argument should be declared in 0 address space
706   // even if the default alloca address space is 5.
707   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
708                    /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
709                    /* BranchProbabilityInfo */ nullptr,
710                    /* AssumptionCache */ nullptr,
711                    /* AllowVarArgs */ true,
712                    /* AllowAlloca */ true,
713                    /* AllocaBlock*/ &Func->getEntryBlock(),
714                    /* Suffix */ ".outlined",
715                    /* ArgsInZeroAddressSpace */ true);
716 
717   EXPECT_TRUE(CE.isEligible());
718 
719   CodeExtractorAnalysisCache CEAC(*Func);
720   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
721   BasicBlock *CommonExit = nullptr;
722   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
723   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
724 
725   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
726   EXPECT_TRUE(Outlined);
727   EXPECT_EQ(Outlined->arg_size(), 1U);
728   // Check address space of outlined argument is ptr in address space 0
729   EXPECT_EQ(Outlined->getArg(0)->getType(),
730             PointerType::get(M->getContext(), 0));
731   EXPECT_FALSE(verifyFunction(*Outlined));
732   EXPECT_FALSE(verifyFunction(*Func));
733 }
734 } // end anonymous namespace
735