xref: /llvm-project/llvm/unittests/CodeGen/PassManagerTest.cpp (revision bb3f5e1fed7c6ba733b7f273e93f5d3930976185)
//===- llvm/unittests/CodeGen/PassManagerTest.cpp - PassManager tests -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Test that the various MachineFunction pass managers, adaptors, analyses, and
9 // analysis managers work.
10 //===----------------------------------------------------------------------===//
11 
#include "llvm/IR/PassManager.h"
#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineModuleInfo.h"
#include "llvm/CodeGen/MachinePassManager.h"
#include "llvm/IR/Analysis.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/TargetParser/Host.h"
#include "llvm/TargetParser/Triple.h"
#include "gtest/gtest.h"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
36 public:
37   struct Result {
38     Result(int Count) : InstructionCount(Count) {}
39     int InstructionCount;
40   };
41 
42   /// The number of instructions in the Function.
43   Result run(Function &F, FunctionAnalysisManager &AM) {
44     return Result(F.getInstructionCount());
45   }
46 
47 private:
48   friend AnalysisInfoMixin<TestFunctionAnalysis>;
49   static AnalysisKey Key;
50 };
51 
52 AnalysisKey TestFunctionAnalysis::Key;
53 
54 class TestMachineFunctionAnalysis
55     : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57   struct Result {
58     Result(int Count) : InstructionCount(Count) {}
59     int InstructionCount;
60   };
61 
62   Result run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) {
63     FunctionAnalysisManager &FAM =
64         AM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
65             .getManager();
66     TestFunctionAnalysis::Result &FAR =
67         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68     return FAR.InstructionCount;
69   }
70 
71 private:
72   friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73   static AnalysisKey Key;
74 };
75 
76 AnalysisKey TestMachineFunctionAnalysis::Key;
77 
78 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
79   TestMachineFunctionPass(int &Count, std::vector<int> &Counts)
80       : Count(Count), Counts(Counts) {}
81 
82   PreservedAnalyses run(MachineFunction &MF,
83                         MachineFunctionAnalysisManager &MFAM) {
84     FunctionAnalysisManager &FAM =
85         MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
86             .getManager();
87     TestFunctionAnalysis::Result &FAR =
88         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
89     Count += FAR.InstructionCount;
90 
91     TestMachineFunctionAnalysis::Result &MFAR =
92         MFAM.getResult<TestMachineFunctionAnalysis>(MF);
93     Count += MFAR.InstructionCount;
94 
95     Counts.push_back(Count);
96 
97     return PreservedAnalyses::none();
98   }
99 
100   int &Count;
101   std::vector<int> &Counts;
102 };
103 
104 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
105   TestMachineModulePass(int &Count, std::vector<int> &Counts)
106       : Count(Count), Counts(Counts) {}
107 
108   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
109     MachineModuleInfo &MMI = MAM.getResult<MachineModuleAnalysis>(M).getMMI();
110     FunctionAnalysisManager &FAM =
111         MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
112     MachineFunctionAnalysisManager &MFAM =
113         MAM.getResult<MachineFunctionAnalysisManagerModuleProxy>(M)
114             .getManager();
115     for (Function &F : M) {
116       MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
117       Count += FAM.getResult<TestFunctionAnalysis>(F).InstructionCount;
118       Count += MFAM.getResult<TestMachineFunctionAnalysis>(MF).InstructionCount;
119     }
120     Counts.push_back(Count);
121     return PreservedAnalyses::all();
122   }
123 
124   int &Count;
125   std::vector<int> &Counts;
126 };
127 
128 struct ReportWarningPass : public PassInfoMixin<ReportWarningPass> {
129   PreservedAnalyses run(MachineFunction &MF,
130                         MachineFunctionAnalysisManager &MFAM) {
131     auto &Ctx = MF.getContext();
132     Ctx.reportWarning(SMLoc(), "Test warning message.");
133     return PreservedAnalyses::all();
134   }
135 };
136 
137 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
138   SMDiagnostic Err;
139   return parseAssemblyString(IR, Err, Context);
140 }
141 
142 class PassManagerTest : public ::testing::Test {
143 protected:
144   LLVMContext Context;
145   std::unique_ptr<Module> M;
146   std::unique_ptr<TargetMachine> TM;
147 
148 public:
149   PassManagerTest()
150       : M(parseIR(Context, "define void @f() {\n"
151                            "entry:\n"
152                            "  call void @g()\n"
153                            "  call void @h()\n"
154                            "  ret void\n"
155                            "}\n"
156                            "define void @g() {\n"
157                            "  ret void\n"
158                            "}\n"
159                            "define void @h() {\n"
160                            "  ret void\n"
161                            "}\n")) {
162     // MachineModuleAnalysis needs a TargetMachine instance.
163     llvm::InitializeAllTargets();
164 
165     std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
166     std::string Error;
167     const Target *TheTarget =
168         TargetRegistry::lookupTarget(TripleName, Error);
169     if (!TheTarget)
170       return;
171 
172     TargetOptions Options;
173     TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
174                                             std::nullopt));
175   }
176 };
177 
178 TEST_F(PassManagerTest, Basic) {
179   if (!TM)
180     GTEST_SKIP();
181 
182   M->setDataLayout(TM->createDataLayout());
183 
184   MachineModuleInfo MMI(TM.get());
185 
186   MachineFunctionAnalysisManager MFAM;
187   LoopAnalysisManager LAM;
188   FunctionAnalysisManager FAM;
189   CGSCCAnalysisManager CGAM;
190   ModuleAnalysisManager MAM;
191   PassBuilder PB(TM.get());
192   PB.registerModuleAnalyses(MAM);
193   PB.registerCGSCCAnalyses(CGAM);
194   PB.registerFunctionAnalyses(FAM);
195   PB.registerLoopAnalyses(LAM);
196   PB.registerMachineFunctionAnalyses(MFAM);
197   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
198 
199   FAM.registerPass([&] { return TestFunctionAnalysis(); });
200   MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
201   MFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
202 
203   int Count = 0;
204   std::vector<int> Counts;
205 
206   ModulePassManager MPM;
207   FunctionPassManager FPM;
208   MachineFunctionPassManager MFPM;
209   MPM.addPass(TestMachineModulePass(Count, Counts));
210   FPM.addPass(createFunctionToMachineFunctionPassAdaptor(
211       TestMachineFunctionPass(Count, Counts)));
212   MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
213   MPM.addPass(TestMachineModulePass(Count, Counts));
214   MFPM.addPass(TestMachineFunctionPass(Count, Counts));
215   FPM = FunctionPassManager();
216   FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM)));
217   MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
218 
219   testing::internal::CaptureStderr();
220   MPM.run(*M, MAM);
221   std::string Output = testing::internal::GetCapturedStderr();
222 
223   EXPECT_EQ((std::vector<int>{10, 16, 18, 20, 30, 36, 38, 40}), Counts);
224   EXPECT_EQ(40, Count);
225 }
226 
227 TEST_F(PassManagerTest, DiagnosticHandler) {
228   if (!TM)
229     GTEST_SKIP();
230 
231   M->setDataLayout(TM->createDataLayout());
232 
233   MachineModuleInfo MMI(TM.get());
234 
235   LoopAnalysisManager LAM;
236   MachineFunctionAnalysisManager MFAM;
237   FunctionAnalysisManager FAM;
238   CGSCCAnalysisManager CGAM;
239   ModuleAnalysisManager MAM;
240   PassBuilder PB(TM.get());
241   PB.registerModuleAnalyses(MAM);
242   PB.registerCGSCCAnalyses(CGAM);
243   PB.registerFunctionAnalyses(FAM);
244   PB.registerLoopAnalyses(LAM);
245   PB.registerMachineFunctionAnalyses(MFAM);
246   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
247 
248   MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
249 
250   ModulePassManager MPM;
251   FunctionPassManager FPM;
252   MachineFunctionPassManager MFPM;
253   MPM.addPass(RequireAnalysisPass<MachineModuleAnalysis, Module>());
254   MFPM.addPass(ReportWarningPass());
255   FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM)));
256   MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
257   testing::internal::CaptureStderr();
258   MPM.run(*M, MAM);
259   std::string Output = testing::internal::GetCapturedStderr();
260 
261   EXPECT_TRUE(Output.find("warning: <unknown>:0: Test warning message.") !=
262               std::string::npos);
263 }
264 
265 } // namespace
266