1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // Test that the various MachineFunction pass managers, adaptors, analyses, and 9 // analysis managers work. 10 //===----------------------------------------------------------------------===// 11 12 #include "llvm/IR/PassManager.h" 13 #include "llvm/Analysis/CGSCCPassManager.h" 14 #include "llvm/Analysis/LoopAnalysisManager.h" 15 #include "llvm/AsmParser/Parser.h" 16 #include "llvm/CodeGen/MachineFunction.h" 17 #include "llvm/CodeGen/MachineModuleInfo.h" 18 #include "llvm/CodeGen/MachinePassManager.h" 19 #include "llvm/IR/Analysis.h" 20 #include "llvm/IR/LLVMContext.h" 21 #include "llvm/IR/Module.h" 22 #include "llvm/MC/TargetRegistry.h" 23 #include "llvm/Passes/PassBuilder.h" 24 #include "llvm/Support/SourceMgr.h" 25 #include "llvm/Support/TargetSelect.h" 26 #include "llvm/Target/TargetMachine.h" 27 #include "llvm/TargetParser/Host.h" 28 #include "llvm/TargetParser/Triple.h" 29 #include "gtest/gtest.h" 30 31 using namespace llvm; 32 33 namespace { 34 35 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> { 36 public: 37 struct Result { 38 Result(int Count) : InstructionCount(Count) {} 39 int InstructionCount; 40 }; 41 42 /// The number of instructions in the Function. 43 Result run(Function &F, FunctionAnalysisManager &AM) { 44 return Result(F.getInstructionCount()); 45 } 46 47 private: 48 friend AnalysisInfoMixin<TestFunctionAnalysis>; 49 static AnalysisKey Key; 50 }; 51 52 AnalysisKey TestFunctionAnalysis::Key; 53 54 class TestMachineFunctionAnalysis 55 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> { 56 public: 57 struct Result { 58 Result(int Count) : InstructionCount(Count) {} 59 int InstructionCount; 60 }; 61 62 Result run(MachineFunction &MF, MachineFunctionAnalysisManager &AM) { 63 FunctionAnalysisManager &FAM = 64 AM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF) 65 .getManager(); 66 TestFunctionAnalysis::Result &FAR = 67 FAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 68 return FAR.InstructionCount; 69 } 70 71 private: 72 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>; 73 static AnalysisKey Key; 74 }; 75 76 AnalysisKey TestMachineFunctionAnalysis::Key; 77 78 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> { 79 TestMachineFunctionPass(int &Count, std::vector<int> &Counts) 80 : Count(Count), Counts(Counts) {} 81 82 PreservedAnalyses run(MachineFunction &MF, 83 MachineFunctionAnalysisManager &MFAM) { 84 FunctionAnalysisManager &FAM = 85 MFAM.getResult<FunctionAnalysisManagerMachineFunctionProxy>(MF) 86 .getManager(); 87 TestFunctionAnalysis::Result &FAR = 88 FAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 89 Count += FAR.InstructionCount; 90 91 TestMachineFunctionAnalysis::Result &MFAR = 92 MFAM.getResult<TestMachineFunctionAnalysis>(MF); 93 Count += MFAR.InstructionCount; 94 95 Counts.push_back(Count); 96 97 return PreservedAnalyses::none(); 98 } 99 100 int &Count; 101 std::vector<int> &Counts; 102 }; 103 104 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> { 105 TestMachineModulePass(int &Count, std::vector<int> &Counts) 106 : Count(Count), Counts(Counts) {} 107 108 PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM) { 109 MachineModuleInfo &MMI = MAM.getResult<MachineModuleAnalysis>(M).getMMI(); 110 FunctionAnalysisManager &FAM = 111 MAM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); 112 MachineFunctionAnalysisManager &MFAM = 113 MAM.getResult<MachineFunctionAnalysisManagerModuleProxy>(M) 114 .getManager(); 115 for (Function &F : M) { 116 MachineFunction &MF = MMI.getOrCreateMachineFunction(F); 117 Count += FAM.getResult<TestFunctionAnalysis>(F).InstructionCount; 118 Count += MFAM.getResult<TestMachineFunctionAnalysis>(MF).InstructionCount; 119 } 120 Counts.push_back(Count); 121 return PreservedAnalyses::all(); 122 } 123 124 int &Count; 125 std::vector<int> &Counts; 126 }; 127 128 struct ReportWarningPass : public PassInfoMixin<ReportWarningPass> { 129 PreservedAnalyses run(MachineFunction &MF, 130 MachineFunctionAnalysisManager &MFAM) { 131 auto &Ctx = MF.getContext(); 132 Ctx.reportWarning(SMLoc(), "Test warning message."); 133 return PreservedAnalyses::all(); 134 } 135 }; 136 137 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) { 138 SMDiagnostic Err; 139 return parseAssemblyString(IR, Err, Context); 140 } 141 142 class PassManagerTest : public ::testing::Test { 143 protected: 144 LLVMContext Context; 145 std::unique_ptr<Module> M; 146 std::unique_ptr<TargetMachine> TM; 147 148 public: 149 PassManagerTest() 150 : M(parseIR(Context, "define void @f() {\n" 151 "entry:\n" 152 " call void @g()\n" 153 " call void @h()\n" 154 " ret void\n" 155 "}\n" 156 "define void @g() {\n" 157 " ret void\n" 158 "}\n" 159 "define void @h() {\n" 160 " ret void\n" 161 "}\n")) { 162 // MachineModuleAnalysis needs a TargetMachine instance. 163 llvm::InitializeAllTargets(); 164 165 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple()); 166 std::string Error; 167 const Target *TheTarget = 168 TargetRegistry::lookupTarget(TripleName, Error); 169 if (!TheTarget) 170 return; 171 172 TargetOptions Options; 173 TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options, 174 std::nullopt)); 175 } 176 }; 177 178 TEST_F(PassManagerTest, Basic) { 179 if (!TM) 180 GTEST_SKIP(); 181 182 M->setDataLayout(TM->createDataLayout()); 183 184 MachineModuleInfo MMI(TM.get()); 185 186 MachineFunctionAnalysisManager MFAM; 187 LoopAnalysisManager LAM; 188 FunctionAnalysisManager FAM; 189 CGSCCAnalysisManager CGAM; 190 ModuleAnalysisManager MAM; 191 PassBuilder PB(TM.get()); 192 PB.registerModuleAnalyses(MAM); 193 PB.registerCGSCCAnalyses(CGAM); 194 PB.registerFunctionAnalyses(FAM); 195 PB.registerLoopAnalyses(LAM); 196 PB.registerMachineFunctionAnalyses(MFAM); 197 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM); 198 199 FAM.registerPass([&] { return TestFunctionAnalysis(); }); 200 MAM.registerPass([&] { return MachineModuleAnalysis(MMI); }); 201 MFAM.registerPass([&] { return TestMachineFunctionAnalysis(); }); 202 203 int Count = 0; 204 std::vector<int> Counts; 205 206 ModulePassManager MPM; 207 FunctionPassManager FPM; 208 MachineFunctionPassManager MFPM; 209 MPM.addPass(TestMachineModulePass(Count, Counts)); 210 FPM.addPass(createFunctionToMachineFunctionPassAdaptor( 211 TestMachineFunctionPass(Count, Counts))); 212 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); 213 MPM.addPass(TestMachineModulePass(Count, Counts)); 214 MFPM.addPass(TestMachineFunctionPass(Count, Counts)); 215 FPM = FunctionPassManager(); 216 FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM))); 217 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); 218 219 testing::internal::CaptureStderr(); 220 MPM.run(*M, MAM); 221 std::string Output = testing::internal::GetCapturedStderr(); 222 223 EXPECT_EQ((std::vector<int>{10, 16, 18, 20, 30, 36, 38, 40}), Counts); 224 EXPECT_EQ(40, Count); 225 } 226 227 TEST_F(PassManagerTest, DiagnosticHandler) { 228 if (!TM) 229 GTEST_SKIP(); 230 231 M->setDataLayout(TM->createDataLayout()); 232 233 MachineModuleInfo MMI(TM.get()); 234 235 LoopAnalysisManager LAM; 236 MachineFunctionAnalysisManager MFAM; 237 FunctionAnalysisManager FAM; 238 CGSCCAnalysisManager CGAM; 239 ModuleAnalysisManager MAM; 240 PassBuilder PB(TM.get()); 241 PB.registerModuleAnalyses(MAM); 242 PB.registerCGSCCAnalyses(CGAM); 243 PB.registerFunctionAnalyses(FAM); 244 PB.registerLoopAnalyses(LAM); 245 PB.registerMachineFunctionAnalyses(MFAM); 246 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM, &MFAM); 247 248 MAM.registerPass([&] { return MachineModuleAnalysis(MMI); }); 249 250 ModulePassManager MPM; 251 FunctionPassManager FPM; 252 MachineFunctionPassManager MFPM; 253 MPM.addPass(RequireAnalysisPass<MachineModuleAnalysis, Module>()); 254 MFPM.addPass(ReportWarningPass()); 255 FPM.addPass(createFunctionToMachineFunctionPassAdaptor(std::move(MFPM))); 256 MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM))); 257 testing::internal::CaptureStderr(); 258 MPM.run(*M, MAM); 259 std::string Output = testing::internal::GetCapturedStderr(); 260 261 EXPECT_TRUE(Output.find("warning: <unknown>:0: Test warning message.") != 262 std::string::npos); 263 } 264 265 } // namespace 266