xref: /llvm-project/llvm/unittests/CodeGen/PassManagerTest.cpp (revision d768bf994f508d7eaf9541a568be3d71096febf5)
//===- llvm/unittests/CodeGen/PassManagerTest.cpp - PassManager tests -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include "gtest/gtest.h"
25 
26 using namespace llvm;
27 
28 namespace {
29 
30 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
31 public:
32   struct Result {
33     Result(int Count) : InstructionCount(Count) {}
34     int InstructionCount;
35   };
36 
37   /// Run the analysis pass over the function and return a result.
38   Result run(Function &F, FunctionAnalysisManager &AM) {
39     int Count = 0;
40     for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI)
41       for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE;
42            ++II)
43         ++Count;
44     return Result(Count);
45   }
46 
47 private:
48   friend AnalysisInfoMixin<TestFunctionAnalysis>;
49   static AnalysisKey Key;
50 };
51 
52 AnalysisKey TestFunctionAnalysis::Key;
53 
54 class TestMachineFunctionAnalysis
55     : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57   struct Result {
58     Result(int Count) : InstructionCount(Count) {}
59     int InstructionCount;
60   };
61 
62   /// Run the analysis pass over the machine function and return a result.
63   Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) {
64     auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM);
65     // Query function analysis result.
66     TestFunctionAnalysis::Result &FAR =
67         MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68     // + 5
69     return FAR.InstructionCount;
70   }
71 
72 private:
73   friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
74   static AnalysisKey Key;
75 };
76 
77 AnalysisKey TestMachineFunctionAnalysis::Key;
78 
// Messages TestMachineFunctionPass reports when its doInitialization /
// doFinalization hooks are forced to fail; the Basic test compares the
// propagated error text against these.
const std::string DoInitErrMsg("doInitialization failed");
const std::string DoFinalErrMsg("doFinalization failed");
81 
82 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
83   TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization,
84                           std::vector<int> &BeforeFinalization,
85                           std::vector<int> &MachineFunctionPassCount)
86       : Count(Count), BeforeInitialization(BeforeInitialization),
87         BeforeFinalization(BeforeFinalization),
88         MachineFunctionPassCount(MachineFunctionPassCount) {}
89 
90   Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) {
91     // Force doInitialization fail by starting with big `Count`.
92     if (Count > 10000)
93       return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode());
94 
95     // + 1
96     ++Count;
97     BeforeInitialization.push_back(Count);
98     return Error::success();
99   }
100   Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) {
101     // Force doFinalization fail by starting with big `Count`.
102     if (Count > 1000)
103       return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode());
104 
105     // + 1
106     ++Count;
107     BeforeFinalization.push_back(Count);
108     return Error::success();
109   }
110 
111   PreservedAnalyses run(MachineFunction &MF,
112                         MachineFunctionAnalysisManager &MFAM) {
113     // Query function analysis result.
114     TestFunctionAnalysis::Result &FAR =
115         MFAM.getResult<TestFunctionAnalysis>(MF.getFunction());
116     // 3 + 1 + 1 = 5
117     Count += FAR.InstructionCount;
118 
119     // Query module analysis result.
120     MachineModuleInfo &MMI =
121         MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent());
122     // 1 + 1 + 1 = 3
123     Count += (MMI.getModule() == MF.getFunction().getParent());
124 
125     // Query machine function analysis result.
126     TestMachineFunctionAnalysis::Result &MFAR =
127         MFAM.getResult<TestMachineFunctionAnalysis>(MF);
128     // 3 + 1 + 1 = 5
129     Count += MFAR.InstructionCount;
130 
131     MachineFunctionPassCount.push_back(Count);
132 
133     return PreservedAnalyses::none();
134   }
135 
136   int &Count;
137   std::vector<int> &BeforeInitialization;
138   std::vector<int> &BeforeFinalization;
139   std::vector<int> &MachineFunctionPassCount;
140 };
141 
142 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
143   TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount)
144       : Count(Count), MachineModulePassCount(MachineModulePassCount) {}
145 
146   Error run(Module &M, MachineFunctionAnalysisManager &MFAM) {
147     MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M);
148     // + 1
149     Count += (MMI.getModule() == &M);
150     MachineModulePassCount.push_back(Count);
151     return Error::success();
152   }
153 
154   PreservedAnalyses run(MachineFunction &MF,
155                         MachineFunctionAnalysisManager &AM) {
156     llvm_unreachable(
157         "This should never be reached because this is machine module pass");
158   }
159 
160   int &Count;
161   std::vector<int> &MachineModulePassCount;
162 };
163 
164 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
165   SMDiagnostic Err;
166   return parseAssemblyString(IR, Err, Context);
167 }
168 
169 class PassManagerTest : public ::testing::Test {
170 protected:
171   LLVMContext Context;
172   std::unique_ptr<Module> M;
173   std::unique_ptr<TargetMachine> TM;
174 
175 public:
176   PassManagerTest()
177       : M(parseIR(Context, "define void @f() {\n"
178                            "entry:\n"
179                            "  call void @g()\n"
180                            "  call void @h()\n"
181                            "  ret void\n"
182                            "}\n"
183                            "define void @g() {\n"
184                            "  ret void\n"
185                            "}\n"
186                            "define void @h() {\n"
187                            "  ret void\n"
188                            "}\n")) {
189     // MachineModuleAnalysis needs a TargetMachine instance.
190     llvm::InitializeAllTargets();
191 
192     std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
193     std::string Error;
194     const Target *TheTarget =
195         TargetRegistry::lookupTarget(TripleName, Error);
196     if (!TheTarget)
197       return;
198 
199     TargetOptions Options;
200     TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
201                                             std::nullopt));
202   }
203 };
204 
205 TEST_F(PassManagerTest, Basic) {
206   if (!TM)
207     GTEST_SKIP();
208 
209   LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
210   M->setDataLayout(TM->createDataLayout());
211 
212   LoopAnalysisManager LAM;
213   FunctionAnalysisManager FAM;
214   CGSCCAnalysisManager CGAM;
215   ModuleAnalysisManager MAM;
216   PassBuilder PB(TM.get());
217   PB.registerModuleAnalyses(MAM);
218   PB.registerFunctionAnalyses(FAM);
219   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
220 
221   FAM.registerPass([&] { return TestFunctionAnalysis(); });
222   FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
223   MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); });
224   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
225 
226   MachineFunctionAnalysisManager MFAM;
227   {
228     // Test move assignment.
229     MachineFunctionAnalysisManager NestedMFAM(FAM, MAM);
230     NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); });
231     NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
232     MFAM = std::move(NestedMFAM);
233   }
234 
235   int Count = 0;
236   std::vector<int> BeforeInitialization[2];
237   std::vector<int> BeforeFinalization[2];
238   std::vector<int> TestMachineFunctionCount[2];
239   std::vector<int> TestMachineModuleCount[2];
240 
241   MachineFunctionPassManager MFPM;
242   {
243     // Test move assignment.
244     MachineFunctionPassManager NestedMFPM;
245     NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0]));
246     NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0],
247                                                BeforeFinalization[0],
248                                                TestMachineFunctionCount[0]));
249     NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1]));
250     NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
251                                                BeforeFinalization[1],
252                                                TestMachineFunctionCount[1]));
253     MFPM = std::move(NestedMFPM);
254   }
255 
256   ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM)));
257 
258   // Check first machine module pass
259   EXPECT_EQ(1u, TestMachineModuleCount[0].size());
260   EXPECT_EQ(3, TestMachineModuleCount[0][0]);
261 
262   // Check first machine function pass
263   EXPECT_EQ(1u, BeforeInitialization[0].size());
264   EXPECT_EQ(1, BeforeInitialization[0][0]);
265   EXPECT_EQ(3u, TestMachineFunctionCount[0].size());
266   EXPECT_EQ(10, TestMachineFunctionCount[0][0]);
267   EXPECT_EQ(13, TestMachineFunctionCount[0][1]);
268   EXPECT_EQ(16, TestMachineFunctionCount[0][2]);
269   EXPECT_EQ(1u, BeforeFinalization[0].size());
270   EXPECT_EQ(31, BeforeFinalization[0][0]);
271 
272   // Check second machine module pass
273   EXPECT_EQ(1u, TestMachineModuleCount[1].size());
274   EXPECT_EQ(17, TestMachineModuleCount[1][0]);
275 
276   // Check second machine function pass
277   EXPECT_EQ(1u, BeforeInitialization[1].size());
278   EXPECT_EQ(2, BeforeInitialization[1][0]);
279   EXPECT_EQ(3u, TestMachineFunctionCount[1].size());
280   EXPECT_EQ(24, TestMachineFunctionCount[1][0]);
281   EXPECT_EQ(27, TestMachineFunctionCount[1][1]);
282   EXPECT_EQ(30, TestMachineFunctionCount[1][2]);
283   EXPECT_EQ(1u, BeforeFinalization[1].size());
284   EXPECT_EQ(32, BeforeFinalization[1][0]);
285 
286   EXPECT_EQ(32, Count);
287 
288   // doInitialization returns error
289   Count = 10000;
290   MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
291                                        BeforeFinalization[1],
292                                        TestMachineFunctionCount[1]));
293   std::string Message;
294   llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
295     Message = Error.getMessage();
296   });
297   EXPECT_EQ(Message, DoInitErrMsg);
298 
299   // doFinalization returns error
300   Count = 1000;
301   MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1],
302                                        BeforeFinalization[1],
303                                        TestMachineFunctionCount[1]));
304   llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) {
305     Message = Error.getMessage();
306   });
307   EXPECT_EQ(Message, DoFinalErrMsg);
308 }
309 
310 } // namespace
311