xref: /llvm-project/llvm/unittests/CodeGen/PassManagerTest.cpp (revision 91e9e3175268c85f4d0e8828d0d392191c250543)
//===- llvm/unittests/CodeGen/PassManagerTest.cpp - PassManager tests -----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 // Test that the various MachineFunction pass managers, adaptors, analyses, and
9 // analysis managers work.
10 //===----------------------------------------------------------------------===//
11 
12 #include "llvm/IR/PassManager.h"
13 #include "llvm/Analysis/CGSCCPassManager.h"
14 #include "llvm/Analysis/LoopAnalysisManager.h"
15 #include "llvm/AsmParser/Parser.h"
16 #include "llvm/CodeGen/MachineFunction.h"
17 #include "llvm/CodeGen/MachineModuleInfo.h"
18 #include "llvm/CodeGen/MachinePassManager.h"
19 #include "llvm/IR/Analysis.h"
20 #include "llvm/IR/LLVMContext.h"
21 #include "llvm/IR/Module.h"
22 #include "llvm/MC/TargetRegistry.h"
23 #include "llvm/Passes/PassBuilder.h"
24 #include "llvm/Support/SourceMgr.h"
25 #include "llvm/Support/TargetSelect.h"
26 #include "llvm/Target/TargetMachine.h"
27 #include "llvm/TargetParser/Host.h"
28 #include "llvm/TargetParser/Triple.h"
29 #include "gtest/gtest.h"
30 
31 using namespace llvm;
32 
33 namespace {
34 
35 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> {
36 public:
37   struct Result {
38     Result(int Count) : InstructionCount(Count) {}
39     int InstructionCount;
40   };
41 
42   /// The number of instructions in the Function.
43   Result run(Function &F, FunctionAnalysisManager &AM) {
44     return Result(F.getInstructionCount());
45   }
46 
47 private:
48   friend AnalysisInfoMixin<TestFunctionAnalysis>;
49   static AnalysisKey Key;
50 };
51 
52 AnalysisKey TestFunctionAnalysis::Key;
53 
54 class TestMachineFunctionAnalysis
55     : public AnalysisInfoMixin<TestMachineFunctionAnalysis> {
56 public:
57   struct Result {
58     Result(int Count) : InstructionCount(Count) {}
59     int InstructionCount;
60   };
61 
62   Result run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) {
63     FunctionAnalysisManager &FAM =
64         AM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
65             .getManager();
66     TestFunctionAnalysis::Result &FAR =
67         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
68     return FAR.InstructionCount;
69   }
70 
71 private:
72   friend AnalysisInfoMixin<TestMachineFunctionAnalysis>;
73   static AnalysisKey Key;
74 };
75 
76 AnalysisKey TestMachineFunctionAnalysis::Key;
77 
78 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> {
79   TestMachineFunctionPass(int &Count, std::vector<int> &Counts)
80       : Count(Count), Counts(Counts) {}
81 
82   PreservedAnalyses run(MachineFunction &MF,
83                         MachineFunctionAnalysisManager &MFAM) {
84     FunctionAnalysisManager &FAM =
85         MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF)
86             .getManager();
87     TestFunctionAnalysis::Result &FAR =
88         FAM.getResult<TestFunctionAnalysis>(MF.getFunction());
89     Count += FAR.InstructionCount;
90 
91     TestMachineFunctionAnalysis::Result &MFAR =
92         MFAM.getResult<TestMachineFunctionAnalysis>(MF);
93     Count += MFAR.InstructionCount;
94 
95     Counts.push_back(Count);
96 
97     return PreservedAnalyses::none();
98   }
99 
100   int &Count;
101   std::vector<int> &Counts;
102 };
103 
104 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> {
105   TestMachineModulePass(int &Count, std::vector<int> &Counts)
106       : Count(Count), Counts(Counts) {}
107 
108   PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) {
109     MachineModuleInfo &MMI = MAM.getResult<MachineModuleAnalysis>(M).getMMI();
110     FunctionAnalysisManager &FAM =
111         MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
112     MachineFunctionAnalysisManager &MFAM =
113         MAM.getResult<MachineFunctionAnalysisManagerModuleProxy>(M)
114             .getManager();
115     for (Function &F : M) {
116       MachineFunction &MF = MMI.getOrCreateMachineFunction(F);
117       Count += FAM.getResult<TestFunctionAnalysis>(F).InstructionCount;
118       Count += MFAM.getResult<TestMachineFunctionAnalysis>(MF).InstructionCount;
119     }
120     Counts.push_back(Count);
121     return PreservedAnalyses::all();
122   }
123 
124   int &Count;
125   std::vector<int> &Counts;
126 };
127 
128 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) {
129   SMDiagnostic Err;
130   return parseAssemblyString(IR, Err, Context);
131 }
132 
133 class PassManagerTest : public ::testing::Test {
134 protected:
135   LLVMContext Context;
136   std::unique_ptr<Module> M;
137   std::unique_ptr<TargetMachine> TM;
138 
139 public:
140   PassManagerTest()
141       : M(parseIR(Context, "define void @f() {\n"
142                            "entry:\n"
143                            "  call void @g()\n"
144                            "  call void @h()\n"
145                            "  ret void\n"
146                            "}\n"
147                            "define void @g() {\n"
148                            "  ret void\n"
149                            "}\n"
150                            "define void @h() {\n"
151                            "  ret void\n"
152                            "}\n")) {
153     // MachineModuleAnalysis needs a TargetMachine instance.
154     llvm::InitializeAllTargets();
155 
156     std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple());
157     std::string Error;
158     const Target *TheTarget =
159         TargetRegistry::lookupTarget(TripleName, Error);
160     if (!TheTarget)
161       return;
162 
163     TargetOptions Options;
164     TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options,
165                                             std::nullopt));
166   }
167 };
168 
169 TEST_F(PassManagerTest, Basic) {
170   if (!TM)
171     GTEST_SKIP();
172 
173   LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get());
174   M->setDataLayout(TM->createDataLayout());
175 
176   MachineModuleInfo MMI(LLVMTM);
177 
178   LoopAnalysisManager LAM;
179   FunctionAnalysisManager FAM;
180   CGSCCAnalysisManager CGAM;
181   ModuleAnalysisManager MAM;
182   MachineFunctionAnalysisManager MFAM;
183   PassBuilder PB(TM.get());
184   PB.registerModuleAnalyses(MAM);
185   PB.registerCGSCCAnalyses(CGAM);
186   PB.registerFunctionAnalyses(FAM);
187   PB.registerLoopAnalyses(LAM);
188   PB.registerMachineFunctionAnalyses(MFAM);
189   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM);
190 
191   FAM.registerPass([&] { return TestFunctionAnalysis(); });
192   MAM.registerPass([&] { return MachineModuleAnalysis(MMI); });
193   MFAM.registerPass([&] { return TestMachineFunctionAnalysis(); });
194 
195   int Count = 0;
196   std::vector<int> Counts;
197 
198   ModulePassManager MPM;
199   MachineFunctionPassManager MFPM;
200   MPM.addPass(TestMachineModulePass(Count, Counts));
201   MPM.addPass(createModuleToMachineFunctionPassAdaptor(
202       TestMachineFunctionPass(Count, Counts)));
203   MPM.addPass(TestMachineModulePass(Count, Counts));
204   MFPM.addPass(TestMachineFunctionPass(Count, Counts));
205   MPM.addPass(createModuleToMachineFunctionPassAdaptor(std::move(MFPM)));
206 
207   MPM.run(*M, MAM);
208 
209   EXPECT_EQ((std::vector<int>{10, 16, 18, 20, 30, 36, 38, 40}), Counts);
210   EXPECT_EQ(40, Count);
211 }
212 
213 } // namespace
214