xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision 4aaa92578686176243a294eeb2ca5697a99edcaa)
1 //===- CodeExtractor.cpp - Unit tests for CodeExtractor -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Utils/CodeExtractor.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/Analysis/AssumptionCache.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/Verifier.h"
19 #include "llvm/IRReader/IRReader.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "gtest/gtest.h"
22 
23 using namespace llvm;
24 
25 namespace {
26 BasicBlock *getBlockByName(Function *F, StringRef name) {
27   for (auto &BB : *F)
28     if (BB.getName() == name)
29       return &BB;
30   return nullptr;
31 }
32 
33 TEST(CodeExtractor, ExitStub) {
34   LLVMContext Ctx;
35   SMDiagnostic Err;
36   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
37     define i32 @foo(i32 %x, i32 %y, i32 %z) {
38     header:
39       %0 = icmp ugt i32 %x, %y
40       br i1 %0, label %body1, label %body2
41 
42     body1:
43       %1 = add i32 %z, 2
44       br label %notExtracted
45 
46     body2:
47       %2 = mul i32 %z, 7
48       br label %notExtracted
49 
50     notExtracted:
51       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
52       %4 = add i32 %3, %x
53       ret i32 %4
54     }
55   )invalid",
56                                                 Err, Ctx));
57 
58   Function *Func = M->getFunction("foo");
59   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
60                                            getBlockByName(Func, "body1"),
61                                            getBlockByName(Func, "body2") };
62 
63   CodeExtractor CE(Candidates);
64   EXPECT_TRUE(CE.isEligible());
65 
66   CodeExtractorAnalysisCache CEAC(*Func);
67   Function *Outlined = CE.extractCodeRegion(CEAC);
68   EXPECT_TRUE(Outlined);
69   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
70   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
71   // Ensure that PHI in exit block has only one incoming value (from code
72   // replacer block).
73   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
74   // Ensure that there is a PHI in outlined function with 2 incoming values.
75   EXPECT_TRUE(ExitSplit &&
76               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
77   EXPECT_FALSE(verifyFunction(*Outlined));
78   EXPECT_FALSE(verifyFunction(*Func));
79 }
80 
81 TEST(CodeExtractor, InputOutputMonitoring) {
82   LLVMContext Ctx;
83   SMDiagnostic Err;
84   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
85     define i32 @foo(i32 %x, i32 %y, i32 %z) {
86     header:
87       %0 = icmp ugt i32 %x, %y
88       br i1 %0, label %body1, label %body2
89 
90     body1:
91       %1 = add i32 %z, 2
92       br label %notExtracted
93 
94     body2:
95       %2 = mul i32 %z, 7
96       br label %notExtracted
97 
98     notExtracted:
99       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
100       %4 = add i32 %3, %x
101       ret i32 %4
102     }
103   )invalid",
104                                                 Err, Ctx));
105 
106   Function *Func = M->getFunction("foo");
107   SmallVector<BasicBlock *, 3> Candidates{getBlockByName(Func, "header"),
108                                           getBlockByName(Func, "body1"),
109                                           getBlockByName(Func, "body2")};
110 
111   CodeExtractor CE(Candidates);
112   EXPECT_TRUE(CE.isEligible());
113 
114   CodeExtractorAnalysisCache CEAC(*Func);
115   SetVector<Value *> Inputs, Outputs;
116   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
117   EXPECT_TRUE(Outlined);
118 
119   EXPECT_EQ(Inputs.size(), 3u);
120   EXPECT_EQ(Inputs[0], Func->getArg(2));
121   EXPECT_EQ(Inputs[1], Func->getArg(0));
122   EXPECT_EQ(Inputs[2], Func->getArg(1));
123   EXPECT_EQ(Outputs.size(), 1u);
124   StoreInst *SI = cast<StoreInst>(Outlined->getArg(3)->user_back());
125   Value *OutputVal = SI->getValueOperand();
126   EXPECT_EQ(Outputs[0], OutputVal);
127   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
128   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
129   // Ensure that PHI in exit block has only one incoming value (from code
130   // replacer block).
131   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
132   // Ensure that there is a PHI in outlined function with 2 incoming values.
133   EXPECT_TRUE(ExitSplit &&
134               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
135   EXPECT_FALSE(verifyFunction(*Outlined));
136   EXPECT_FALSE(verifyFunction(*Func));
137 }
138 
139 TEST(CodeExtractor, ExitBlockOrderingPhis) {
140   LLVMContext Ctx;
141   SMDiagnostic Err;
142   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
143     define void @foo(i32 %a, i32 %b) {
144     entry:
145       %0 = alloca i32, align 4
146       br label %test0
147     test0:
148       %c = load i32, i32* %0, align 4
149       br label %test1
150     test1:
151       %e = load i32, i32* %0, align 4
152       br i1 true, label %first, label %test
153     test:
154       %d = load i32, i32* %0, align 4
155       br i1 true, label %first, label %next
156     first:
157       %1 = phi i32 [ %c, %test ], [ %e, %test1 ]
158       ret void
159     next:
160       %2 = add i32 %d, 1
161       %3 = add i32 %e, 1
162       ret void
163     }
164   )invalid",
165                                                 Err, Ctx));
166   Function *Func = M->getFunction("foo");
167   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
168                                            getBlockByName(Func, "test1"),
169                                            getBlockByName(Func, "test") };
170 
171   CodeExtractor CE(Candidates);
172   EXPECT_TRUE(CE.isEligible());
173 
174   CodeExtractorAnalysisCache CEAC(*Func);
175   Function *Outlined = CE.extractCodeRegion(CEAC);
176   EXPECT_TRUE(Outlined);
177 
178   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
179   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
180 
181   Instruction *FirstTerm = FirstExitStub->getTerminator();
182   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
183   EXPECT_TRUE(FirstReturn);
184   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
185   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
186 
187   Instruction *NextTerm = NextExitStub->getTerminator();
188   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
189   EXPECT_TRUE(NextReturn);
190   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
191   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
192 
193   EXPECT_FALSE(verifyFunction(*Outlined));
194   EXPECT_FALSE(verifyFunction(*Func));
195 }
196 
197 TEST(CodeExtractor, ExitBlockOrdering) {
198   LLVMContext Ctx;
199   SMDiagnostic Err;
200   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
201     define void @foo(i32 %a, i32 %b) {
202     entry:
203       %0 = alloca i32, align 4
204       br label %test0
205     test0:
206       %c = load i32, i32* %0, align 4
207       br label %test1
208     test1:
209       %e = load i32, i32* %0, align 4
210       br i1 true, label %first, label %test
211     test:
212       %d = load i32, i32* %0, align 4
213       br i1 true, label %first, label %next
214     first:
215       ret void
216     next:
217       %1 = add i32 %d, 1
218       %2 = add i32 %e, 1
219       ret void
220     }
221   )invalid",
222                                                 Err, Ctx));
223   Function *Func = M->getFunction("foo");
224   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "test0"),
225                                            getBlockByName(Func, "test1"),
226                                            getBlockByName(Func, "test") };
227 
228   CodeExtractor CE(Candidates);
229   EXPECT_TRUE(CE.isEligible());
230 
231   CodeExtractorAnalysisCache CEAC(*Func);
232   Function *Outlined = CE.extractCodeRegion(CEAC);
233   EXPECT_TRUE(Outlined);
234 
235   BasicBlock *FirstExitStub = getBlockByName(Outlined, "first.exitStub");
236   BasicBlock *NextExitStub = getBlockByName(Outlined, "next.exitStub");
237 
238   Instruction *FirstTerm = FirstExitStub->getTerminator();
239   ReturnInst *FirstReturn = dyn_cast<ReturnInst>(FirstTerm);
240   EXPECT_TRUE(FirstReturn);
241   ConstantInt *CIFirst = dyn_cast<ConstantInt>(FirstReturn->getReturnValue());
242   EXPECT_TRUE(CIFirst->getLimitedValue() == 1u);
243 
244   Instruction *NextTerm = NextExitStub->getTerminator();
245   ReturnInst *NextReturn = dyn_cast<ReturnInst>(NextTerm);
246   EXPECT_TRUE(NextReturn);
247   ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
248   EXPECT_TRUE(CINext->getLimitedValue() == 0u);
249 
250   EXPECT_FALSE(verifyFunction(*Outlined));
251   EXPECT_FALSE(verifyFunction(*Func));
252 }
253 
254 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
255   LLVMContext Ctx;
256   SMDiagnostic Err;
257   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
258     define i32 @foo() {
259     header:
260       br i1 undef, label %extracted1, label %pred
261 
262     pred:
263       br i1 undef, label %exit1, label %exit2
264 
265     extracted1:
266       br i1 undef, label %extracted2, label %exit1
267 
268     extracted2:
269       br label %exit2
270 
271     exit1:
272       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
273       ret i32 %0
274 
275     exit2:
276       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
277       ret i32 %1
278     }
279   )invalid", Err, Ctx));
280 
281   Function *Func = M->getFunction("foo");
282   SmallVector<BasicBlock *, 2> ExtractedBlocks{
283     getBlockByName(Func, "extracted1"),
284     getBlockByName(Func, "extracted2")
285   };
286 
287   CodeExtractor CE(ExtractedBlocks);
288   EXPECT_TRUE(CE.isEligible());
289 
290   CodeExtractorAnalysisCache CEAC(*Func);
291   Function *Outlined = CE.extractCodeRegion(CEAC);
292   EXPECT_TRUE(Outlined);
293   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
294   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
295   // Ensure that PHIs in exits are not splitted (since that they have only one
296   // incoming value from extracted region).
297   EXPECT_TRUE(Exit1 &&
298           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
299   EXPECT_TRUE(Exit2 &&
300           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
301   EXPECT_FALSE(verifyFunction(*Outlined));
302   EXPECT_FALSE(verifyFunction(*Func));
303 }
304 
305 TEST(CodeExtractor, StoreOutputInvokeResultAfterEHPad) {
306   LLVMContext Ctx;
307   SMDiagnostic Err;
308   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
309     declare i8 @hoge()
310 
311     define i32 @foo() personality i8* null {
312       entry:
313         %call = invoke i8 @hoge()
314                 to label %invoke.cont unwind label %lpad
315 
316       invoke.cont:                                      ; preds = %entry
317         unreachable
318 
319       lpad:                                             ; preds = %entry
320         %0 = landingpad { i8*, i32 }
321                 catch i8* null
322         br i1 undef, label %catch, label %finally.catchall
323 
324       catch:                                            ; preds = %lpad
325         %call2 = invoke i8 @hoge()
326                 to label %invoke.cont2 unwind label %lpad2
327 
328       invoke.cont2:                                    ; preds = %catch
329         %call3 = invoke i8 @hoge()
330                 to label %invoke.cont3 unwind label %lpad2
331 
332       invoke.cont3:                                    ; preds = %invoke.cont2
333         unreachable
334 
335       lpad2:                                           ; preds = %invoke.cont2, %catch
336         %ex.1 = phi i8* [ undef, %invoke.cont2 ], [ null, %catch ]
337         %1 = landingpad { i8*, i32 }
338                 catch i8* null
339         br label %finally.catchall
340 
341       finally.catchall:                                 ; preds = %lpad33, %lpad
342         %ex.2 = phi i8* [ %ex.1, %lpad2 ], [ null, %lpad ]
343         unreachable
344     }
345   )invalid", Err, Ctx));
346 
347 	if (!M) {
348     Err.print("unit", errs());
349     exit(1);
350   }
351 
352   Function *Func = M->getFunction("foo");
353   EXPECT_FALSE(verifyFunction(*Func, &errs()));
354 
355   SmallVector<BasicBlock *, 2> ExtractedBlocks{
356     getBlockByName(Func, "catch"),
357     getBlockByName(Func, "invoke.cont2"),
358     getBlockByName(Func, "invoke.cont3"),
359     getBlockByName(Func, "lpad2")
360   };
361 
362   CodeExtractor CE(ExtractedBlocks);
363   EXPECT_TRUE(CE.isEligible());
364 
365   CodeExtractorAnalysisCache CEAC(*Func);
366   Function *Outlined = CE.extractCodeRegion(CEAC);
367   EXPECT_TRUE(Outlined);
368   EXPECT_FALSE(verifyFunction(*Outlined, &errs()));
369   EXPECT_FALSE(verifyFunction(*Func, &errs()));
370 }
371 
372 TEST(CodeExtractor, StoreOutputInvokeResultInExitStub) {
373   LLVMContext Ctx;
374   SMDiagnostic Err;
375   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
376     declare i32 @bar()
377 
378     define i32 @foo() personality i8* null {
379     entry:
380       %0 = invoke i32 @bar() to label %exit unwind label %lpad
381 
382     exit:
383       ret i32 %0
384 
385     lpad:
386       %1 = landingpad { i8*, i32 }
387               cleanup
388       resume { i8*, i32 } %1
389     }
390   )invalid",
391                                                 Err, Ctx));
392 
393   Function *Func = M->getFunction("foo");
394   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "entry"),
395                                        getBlockByName(Func, "lpad") };
396 
397   CodeExtractor CE(Blocks);
398   EXPECT_TRUE(CE.isEligible());
399 
400   CodeExtractorAnalysisCache CEAC(*Func);
401   Function *Outlined = CE.extractCodeRegion(CEAC);
402   EXPECT_TRUE(Outlined);
403   EXPECT_FALSE(verifyFunction(*Outlined));
404   EXPECT_FALSE(verifyFunction(*Func));
405 }
406 
407 TEST(CodeExtractor, ExtractAndInvalidateAssumptionCache) {
408   LLVMContext Ctx;
409   SMDiagnostic Err;
410   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
411         target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
412         target triple = "aarch64"
413 
414         %b = type { i64 }
415         declare void @g(i8*)
416 
417         declare void @llvm.assume(i1) #0
418 
419         define void @test() {
420         entry:
421           br label %label
422 
423         label:
424           %0 = load %b*, %b** inttoptr (i64 8 to %b**), align 8
425           %1 = getelementptr inbounds %b, %b* %0, i64 undef, i32 0
426           %2 = load i64, i64* %1, align 8
427           %3 = icmp ugt i64 %2, 1
428           br i1 %3, label %if.then, label %if.else
429 
430         if.then:
431           unreachable
432 
433         if.else:
434           call void @g(i8* undef)
435           store i64 undef, i64* null, align 536870912
436           %4 = icmp eq i64 %2, 0
437           call void @llvm.assume(i1 %4)
438           unreachable
439         }
440 
441         attributes #0 = { nounwind willreturn }
442   )ir",
443                                                 Err, Ctx));
444 
445   assert(M && "Could not parse module?");
446   Function *Func = M->getFunction("test");
447   SmallVector<BasicBlock *, 1> Blocks{ getBlockByName(Func, "if.else") };
448   AssumptionCache AC(*Func);
449   CodeExtractor CE(Blocks, nullptr, false, nullptr, nullptr, &AC);
450   EXPECT_TRUE(CE.isEligible());
451 
452   CodeExtractorAnalysisCache CEAC(*Func);
453   Function *Outlined = CE.extractCodeRegion(CEAC);
454   EXPECT_TRUE(Outlined);
455   EXPECT_FALSE(verifyFunction(*Outlined));
456   EXPECT_FALSE(verifyFunction(*Func));
457   EXPECT_FALSE(CE.verifyAssumptionCache(*Func, *Outlined, &AC));
458 }
459 
460 TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
461   LLVMContext Ctx;
462   SMDiagnostic Err;
463   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
464     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
465     target triple = "x86_64-unknown-linux-gnu"
466 
467     declare void @use(i32*)
468     declare void @llvm.lifetime.start.p0i8(i64, i8*)
469     declare void @llvm.lifetime.end.p0i8(i64, i8*)
470 
471     define void @foo() {
472     entry:
473       %0 = alloca i32
474       br label %extract
475 
476     extract:
477       %1 = bitcast i32* %0 to i8*
478       call void @llvm.lifetime.start.p0i8(i64 4, i8* %1)
479       call void @use(i32* %0)
480       br label %exit
481 
482     exit:
483       call void @use(i32* %0)
484       call void @llvm.lifetime.end.p0i8(i64 4, i8* %1)
485       ret void
486     }
487   )ir",
488                                                 Err, Ctx));
489 
490   Function *Func = M->getFunction("foo");
491   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
492 
493   CodeExtractor CE(Blocks);
494   EXPECT_TRUE(CE.isEligible());
495 
496   CodeExtractorAnalysisCache CEAC(*Func);
497   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
498   BasicBlock *CommonExit = nullptr;
499   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
500   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
501   EXPECT_EQ(Outputs.size(), 0U);
502 
503   Function *Outlined = CE.extractCodeRegion(CEAC);
504   EXPECT_TRUE(Outlined);
505   EXPECT_FALSE(verifyFunction(*Outlined));
506   EXPECT_FALSE(verifyFunction(*Func));
507 }
508 
509 TEST(CodeExtractor, PartialAggregateArgs) {
510   LLVMContext Ctx;
511   SMDiagnostic Err;
512   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
513     target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
514     target triple = "x86_64-unknown-linux-gnu"
515 
516     declare void @use(i32)
517 
518     define void @foo(i32 %a, i32 %b, i32 %c) {
519     entry:
520       br label %extract
521 
522     extract:
523       call void @use(i32 %a)
524       call void @use(i32 %b)
525       call void @use(i32 %c)
526       br label %exit
527 
528     exit:
529       ret void
530     }
531   )ir",
532                                                 Err, Ctx));
533 
534   Function *Func = M->getFunction("foo");
535   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
536 
537   // Create the CodeExtractor with arguments aggregation enabled.
538   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
539                    /* AggregateArgs */ true);
540   EXPECT_TRUE(CE.isEligible());
541 
542   CodeExtractorAnalysisCache CEAC(*Func);
543   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
544   BasicBlock *CommonExit = nullptr;
545   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
546   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
547   // Exclude the first input from the argument aggregate.
548   CE.excludeArgFromAggregate(Inputs[0]);
549 
550   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
551   EXPECT_TRUE(Outlined);
552   // Expect 2 arguments in the outlined function: the excluded input and the
553   // struct aggregate for the remaining inputs.
554   EXPECT_EQ(Outlined->arg_size(), 2U);
555   EXPECT_FALSE(verifyFunction(*Outlined));
556   EXPECT_FALSE(verifyFunction(*Func));
557 }
558 
559 /// Regression test to ensure we don't crash trying to set the name of the ptr
560 /// argument
561 TEST(CodeExtractor, PartialAggregateArgs2) {
562   LLVMContext Ctx;
563   SMDiagnostic Err;
564   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
565     declare void @usei(i32)
566     declare void @usep(ptr)
567 
568     define void @foo(i32 %a, i32 %b, ptr %p) {
569     entry:
570       br label %extract
571 
572     extract:
573       call void @usei(i32 %a)
574       call void @usei(i32 %b)
575       call void @usep(ptr %p)
576       br label %exit
577 
578     exit:
579       ret void
580     }
581   )ir",
582                                                 Err, Ctx));
583 
584   Function *Func = M->getFunction("foo");
585   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
586 
587   // Create the CodeExtractor with arguments aggregation enabled.
588   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
589                    /* AggregateArgs */ true);
590   EXPECT_TRUE(CE.isEligible());
591 
592   CodeExtractorAnalysisCache CEAC(*Func);
593   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
594   BasicBlock *CommonExit = nullptr;
595   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
596   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
597   // Exclude the last input from the argument aggregate.
598   CE.excludeArgFromAggregate(Inputs[2]);
599 
600   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
601   EXPECT_TRUE(Outlined);
602   EXPECT_FALSE(verifyFunction(*Outlined));
603   EXPECT_FALSE(verifyFunction(*Func));
604 }
605 
606 TEST(CodeExtractor, OpenMPAggregateArgs) {
607   LLVMContext Ctx;
608   SMDiagnostic Err;
609   std::unique_ptr<Module> M(parseAssemblyString(R"ir(
610     target datalayout = "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8:9"
611     target triple = "amdgcn-amd-amdhsa"
612 
613     define void @foo(ptr %0) {
614       %2= alloca ptr, align 8, addrspace(5)
615       %3 = addrspacecast ptr addrspace(5) %2 to ptr
616       store ptr %0, ptr %3, align 8
617       %4 = load ptr, ptr %3, align 8
618       br label %entry
619 
620    entry:
621       br label %extract
622 
623     extract:
624       store i64 10, ptr %4, align 4
625       br label %exit
626 
627     exit:
628       ret void
629     }
630   )ir",
631                                                 Err, Ctx));
632   Function *Func = M->getFunction("foo");
633   SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
634 
635   // Create the CodeExtractor with arguments aggregation enabled.
636   // Outlined function argument should be declared in 0 address space
637   // even if the default alloca address space is 5.
638   CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
639                    /* AggregateArgs */ true, /* BlockFrequencyInfo */ nullptr,
640                    /* BranchProbabilityInfo */ nullptr,
641                    /* AssumptionCache */ nullptr,
642                    /* AllowVarArgs */ true,
643                    /* AllowAlloca */ true,
644                    /* AllocaBlock*/ &Func->getEntryBlock(),
645                    /* Suffix */ ".outlined",
646                    /* ArgsInZeroAddressSpace */ true);
647 
648   EXPECT_TRUE(CE.isEligible());
649 
650   CodeExtractorAnalysisCache CEAC(*Func);
651   SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
652   BasicBlock *CommonExit = nullptr;
653   CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
654   CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
655 
656   Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
657   EXPECT_TRUE(Outlined);
658   EXPECT_EQ(Outlined->arg_size(), 1U);
659   // Check address space of outlined argument is ptr in address space 0
660   EXPECT_EQ(Outlined->getArg(0)->getType(),
661             PointerType::get(M->getContext(), 0));
662   EXPECT_FALSE(verifyFunction(*Outlined));
663   EXPECT_FALSE(verifyFunction(*Func));
664 }
665 } // end anonymous namespace
666