xref: /llvm-project/llvm/unittests/CodeGen/PassManagerTest.cpp (revision f6ca37bf1809b2fa8f6615d4a30eadf8f479c700)
1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Test that the various MachineFunction pass managers, adaptors, analyses, and
9 // analysis managers work.
10 //===----------------------------------------------------------------------===//
11 
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/Analysis/CGSCCPassManager.h"
14 #include "llvm/Analysis/LoopAnalysisManager.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/CodeGen/MachineFunction.h"
17 #include "llvm/CodeGen/MachineModuleInfo.h"
18 #include "llvm/CodeGen/MachinePassManager.h"
19 #include "llvm/IR/Analysis.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/MC/TargetRegistry.h"
23 #include "llvm/Passes/PassBuilder.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/Support/TargetSelect.h"
26 #include "llvm/Target/TargetMachine.h"
27 #include "llvm/TargetParser/Host.h"
28 #include "llvm/TargetParser/Triple.h"
29 #include "gtest/gtest.h"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
36 public:
37   struct Result {
38     Result(int Count) : InstructionCount(Count) {}
39     int InstructionCount;
40   };
41 
42   /// The number of instructions in the Function.
43   Result run(Function &F, FunctionAnalysisManager &AM) {
44     return Result(F.getInstructionCount());
45   }
46 
47 private:
48   friend AnalysisInfoMixin<TestFunctionAnalysis>;
49   static AnalysisKey Key;
50 };
51 
52 AnalysisKey TestFunctionAnalysis::Key;
53 
54 class TestMachineFunctionAnalysis
55     : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57   struct Result {
58     Result(int Count) : InstructionCount(Count) {}
59     int InstructionCount;
60   };
61 
62   Result run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) {
63     FunctionAnalysisManager &FAM =
64         AM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
65             .getManager();
66     TestFunctionAnalysis::Result &FAR =
67         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68     return FAR.InstructionCount;
69   }
70 
71 private:
72   friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73   static AnalysisKey Key;
74 };
75 
76 AnalysisKey TestMachineFunctionAnalysis::Key;
77 
78 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
79   TestMachineFunctionPass(int &Count, std::vector<int> &Counts)
80       : Count(Count), Counts(Counts) {}
81 
82   PreservedAnalyses run(MachineFunction &MF,
83                         MachineFunctionAnalysisManager &MFAM) {
84     FunctionAnalysisManager &FAM =
85         MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
86             .getManager();
87     TestFunctionAnalysis::Result &FAR =
88         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
89     Count += FAR.InstructionCount;
90 
91     TestMachineFunctionAnalysis::Result &MFAR =
92         MFAM.getResult<TestMachineFunctionAnalysis>(MF);
93     Count += MFAR.InstructionCount;
94 
95     Counts.push_back(Count);
96 
97     return PreservedAnalyses::none();
98   }
99 
100   int &Count;
101   std::vector<int> &Counts;
102 };
103 
104 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
105   TestMachineModulePass(int &Count, std::vector<int> &Counts)
106       : Count(Count), Counts(Counts) {}
107 
108   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
109     MachineModuleInfo &MMI = MAM.getResult<MachineModuleAnalysis>(M).getMMI();
110     FunctionAnalysisManager &FAM =
111         MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
112     MachineFunctionAnalysisManager &MFAM =
113         MAM.getResult<MachineFunctionAnalysisManagerModuleProxy>(M)
114             .getManager();
115     for (Function &F : M) {
116       MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
117       Count += FAM.getResult<TestFunctionAnalysis>(F).InstructionCount;
118       Count += MFAM.getResult<TestMachineFunctionAnalysis>(MF).InstructionCount;
119     }
120     Counts.push_back(Count);
121     return PreservedAnalyses::all();
122   }
123 
124   int &Count;
125   std::vector<int> &Counts;
126 };
127 
128 struct ReportWarningPass : public PassInfoMixin<ReportWarningPass> {
129   PreservedAnalyses run(MachineFunction &MF,
130                         MachineFunctionAnalysisManager &MFAM) {
131     auto &Ctx = MF.getContext();
132     Ctx.reportWarning(SMLoc(), "Test warning message.");
133     return PreservedAnalyses::all();
134   }
135 };
136 
137 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
138   SMDiagnostic Err;
139   return parseAssemblyString(IR, Err, Context);
140 }
141 
142 class PassManagerTest : public ::testing::Test {
143 protected:
144   LLVMContext Context;
145   std::unique_ptr<Module> M;
146   std::unique_ptr<TargetMachine> TM;
147 
148 public:
149   PassManagerTest()
150       : M(parseIR(Context, "define void @f() {\n"
151                            "entry:\n"
152                            "  call void @g()\n"
153                            "  call void @h()\n"
154                            "  ret void\n"
155                            "}\n"
156                            "define void @g() {\n"
157                            "  ret void\n"
158                            "}\n"
159                            "define void @h() {\n"
160                            "  ret void\n"
161                            "}\n")) {
162     // MachineModuleAnalysis needs a TargetMachine instance.
163     llvm::InitializeAllTargets();
164 
165     std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
166     std::string Error;
167     const Target *TheTarget =
168         TargetRegistry::lookupTarget(TripleName, Error);
169     if (!TheTarget)
170       return;
171 
172     TargetOptions Options;
173     TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
174                                             std::nullopt));
175   }
176 };
177 
178 TEST_F(PassManagerTest, Basic) {
179   if (!TM)
180     GTEST_SKIP();
181 
182   LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
183   M->setDataLayout(TM->createDataLayout());
184 
185   MachineModuleInfo MMI(LLVMTM);
186 
187   LoopAnalysisManager LAM;
188   MachineFunctionAnalysisManager MFAM;
189   FunctionAnalysisManager FAM;
190   CGSCCAnalysisManager CGAM;
191   ModuleAnalysisManager MAM;
192   PassBuilder PB(TM.get());
193   PB.registerModuleAnalyses(MAM);
194   PB.registerCGSCCAnalyses(CGAM);
195   PB.registerFunctionAnalyses(FAM);
196   PB.registerLoopAnalyses(LAM);
197   PB.registerMachineFunctionAnalyses(MFAM);
198   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
199 
200   FAM.registerPass([&] { return TestFunctionAnalysis(); });
201   MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
202   MFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
203 
204   int Count = 0;
205   std::vector<int> Counts;
206 
207   ModulePassManager MPM;
208   MachineFunctionPassManager MFPM;
209   MPM.addPass(TestMachineModulePass(Count, Counts));
210   MPM.addPass(createModuleToMachineFunctionPassAdaptor(
211       TestMachineFunctionPass(Count, Counts)));
212   MPM.addPass(TestMachineModulePass(Count, Counts));
213   MFPM.addPass(TestMachineFunctionPass(Count, Counts));
214   MPM.addPass(createModuleToMachineFunctionPassAdaptor(std::move(MFPM)));
215 
216   testing::internal::CaptureStderr();
217   MPM.run(*M, MAM);
218   std::string Output = testing::internal::GetCapturedStderr();
219 
220   EXPECT_EQ((std::vector<int>{10, 16, 18, 20, 30, 36, 38, 40}), Counts);
221   EXPECT_EQ(40, Count);
222 }
223 
224 TEST_F(PassManagerTest, DiagnosticHandler) {
225   if (!TM)
226     GTEST_SKIP();
227 
228   LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
229   M->setDataLayout(TM->createDataLayout());
230 
231   MachineModuleInfo MMI(LLVMTM);
232 
233   LoopAnalysisManager LAM;
234   MachineFunctionAnalysisManager MFAM;
235   FunctionAnalysisManager FAM;
236   CGSCCAnalysisManager CGAM;
237   ModuleAnalysisManager MAM;
238   PassBuilder PB(TM.get());
239   PB.registerModuleAnalyses(MAM);
240   PB.registerCGSCCAnalyses(CGAM);
241   PB.registerFunctionAnalyses(FAM);
242   PB.registerLoopAnalyses(LAM);
243   PB.registerMachineFunctionAnalyses(MFAM);
244   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
245 
246   MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
247 
248   ModulePassManager MPM;
249   FunctionPassManager FPM;
250   MachineFunctionPassManager MFPM;
251   MFPM.addPass(ReportWarningPass());
252   MPM.addPass(createModuleToMachineFunctionPassAdaptor(std::move(MFPM)));
253   testing::internal::CaptureStderr();
254   MPM.run(*M, MAM);
255   std::string Output = testing::internal::GetCapturedStderr();
256 
257   EXPECT_TRUE(Output.find("warning: <unknown>:0: Test warning message.") !=
258               std::string::npos);
259 }
260 
261 } // namespace
262