xref: /llvm-project/llvm/unittests/Transforms/Utils/CodeExtractorTest.cpp (revision d129569e348a65252ed1d88e0edce997462ecc24)
1 //===- CodeExtractorTest.cpp - Unit tests for CodeExtractor ---------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include "llvm/Transforms/Utils/CodeExtractor.h"
11 #include "llvm/AsmParser/Parser.h"
12 #include "llvm/IR/BasicBlock.h"
13 #include "llvm/IR/Dominators.h"
14 #include "llvm/IR/Instructions.h"
15 #include "llvm/IR/LLVMContext.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/Verifier.h"
18 #include "llvm/IRReader/IRReader.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 using namespace llvm;
23 
24 namespace {
25 BasicBlock *getBlockByName(Function *F, StringRef name) {
26   for (auto &BB : *F)
27     if (BB.getName() == name)
28       return &BB;
29   return nullptr;
30 }
31 
32 TEST(CodeExtractor, ExitStub) {
33   LLVMContext Ctx;
34   SMDiagnostic Err;
35   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
36     define i32 @foo(i32 %x, i32 %y, i32 %z) {
37     header:
38       %0 = icmp ugt i32 %x, %y
39       br i1 %0, label %body1, label %body2
40 
41     body1:
42       %1 = add i32 %z, 2
43       br label %notExtracted
44 
45     body2:
46       %2 = mul i32 %z, 7
47       br label %notExtracted
48 
49     notExtracted:
50       %3 = phi i32 [ %1, %body1 ], [ %2, %body2 ]
51       %4 = add i32 %3, %x
52       ret i32 %4
53     }
54   )invalid",
55                                                 Err, Ctx));
56 
57   Function *Func = M->getFunction("foo");
58   SmallVector<BasicBlock *, 3> Candidates{ getBlockByName(Func, "header"),
59                                            getBlockByName(Func, "body1"),
60                                            getBlockByName(Func, "body2") };
61 
62   DominatorTree DT(*Func);
63   CodeExtractor CE(Candidates, &DT);
64   EXPECT_TRUE(CE.isEligible());
65 
66   Function *Outlined = CE.extractCodeRegion();
67   EXPECT_TRUE(Outlined);
68   BasicBlock *Exit = getBlockByName(Func, "notExtracted");
69   BasicBlock *ExitSplit = getBlockByName(Outlined, "notExtracted.split");
70   // Ensure that PHI in exit block has only one incoming value (from code
71   // replacer block).
72   EXPECT_TRUE(Exit && cast<PHINode>(Exit->front()).getNumIncomingValues() == 1);
73   // Ensure that there is a PHI in outlined function with 2 incoming values.
74   EXPECT_TRUE(ExitSplit &&
75               cast<PHINode>(ExitSplit->front()).getNumIncomingValues() == 2);
76   EXPECT_FALSE(verifyFunction(*Outlined));
77   EXPECT_FALSE(verifyFunction(*Func));
78 }
79 
80 TEST(CodeExtractor, ExitPHIOnePredFromRegion) {
81   LLVMContext Ctx;
82   SMDiagnostic Err;
83   std::unique_ptr<Module> M(parseAssemblyString(R"invalid(
84     define i32 @foo() {
85     header:
86       br i1 undef, label %extracted1, label %pred
87 
88     pred:
89       br i1 undef, label %exit1, label %exit2
90 
91     extracted1:
92       br i1 undef, label %extracted2, label %exit1
93 
94     extracted2:
95       br label %exit2
96 
97     exit1:
98       %0 = phi i32 [ 1, %extracted1 ], [ 2, %pred ]
99       ret i32 %0
100 
101     exit2:
102       %1 = phi i32 [ 3, %extracted2 ], [ 4, %pred ]
103       ret i32 %1
104     }
105   )invalid", Err, Ctx));
106 
107   Function *Func = M->getFunction("foo");
108   SmallVector<BasicBlock *, 2> ExtractedBlocks{
109     getBlockByName(Func, "extracted1"),
110     getBlockByName(Func, "extracted2")
111   };
112 
113   DominatorTree DT(*Func);
114   CodeExtractor CE(ExtractedBlocks, &DT);
115   EXPECT_TRUE(CE.isEligible());
116 
117   Function *Outlined = CE.extractCodeRegion();
118   EXPECT_TRUE(Outlined);
119   BasicBlock *Exit1 = getBlockByName(Func, "exit1");
120   BasicBlock *Exit2 = getBlockByName(Func, "exit2");
121   // Ensure that PHIs in exits are not splitted (since that they have only one
122   // incoming value from extracted region).
123   EXPECT_TRUE(Exit1 &&
124           cast<PHINode>(Exit1->front()).getNumIncomingValues() == 2);
125   EXPECT_TRUE(Exit2 &&
126           cast<PHINode>(Exit2->front()).getNumIncomingValues() == 2);
127   EXPECT_FALSE(verifyFunction(*Outlined));
128   EXPECT_FALSE(verifyFunction(*Func));
129 }
130 } // end anonymous namespace
131