1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/CGSCCPassManager.h" 10 #include "llvm/Analysis/LoopAnalysisManager.h" 11 #include "llvm/AsmParser/Parser.h" 12 #include "llvm/CodeGen/MachineFunction.h" 13 #include "llvm/CodeGen/MachineModuleInfo.h" 14 #include "llvm/CodeGen/MachinePassManager.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/MC/TargetRegistry.h" 18 #include "llvm/Passes/PassBuilder.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "llvm/Support/TargetSelect.h" 21 #include "llvm/Target/TargetMachine.h" 22 #include "llvm/TargetParser/Host.h" 23 #include "llvm/TargetParser/Triple.h" 24 #include "gtest/gtest.h" 25 26 using namespace llvm; 27 28 namespace { 29 30 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> { 31 public: 32 struct Result { 33 Result(int Count) : InstructionCount(Count) {} 34 int InstructionCount; 35 }; 36 37 /// Run the analysis pass over the function and return a result. 38 Result run(Function &F, FunctionAnalysisManager &AM) { 39 int Count = 0; 40 for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) 41 for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE; 42 ++II) 43 ++Count; 44 return Result(Count); 45 } 46 47 private: 48 friend AnalysisInfoMixin<TestFunctionAnalysis>; 49 static AnalysisKey Key; 50 }; 51 52 AnalysisKey TestFunctionAnalysis::Key; 53 54 class TestMachineFunctionAnalysis 55 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> { 56 public: 57 struct Result { 58 Result(int Count) : InstructionCount(Count) {} 59 int InstructionCount; 60 }; 61 62 /// Run the analysis pass over the machine function and return a result. 63 Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) { 64 auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM); 65 // Query function analysis result. 66 TestFunctionAnalysis::Result &FAR = 67 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 68 // + 5 69 return FAR.InstructionCount; 70 } 71 72 private: 73 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>; 74 static AnalysisKey Key; 75 }; 76 77 AnalysisKey TestMachineFunctionAnalysis::Key; 78 79 const std::string DoInitErrMsg = "doInitialization failed"; 80 const std::string DoFinalErrMsg = "doFinalization failed"; 81 82 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> { 83 TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization, 84 std::vector<int> &BeforeFinalization, 85 std::vector<int> &MachineFunctionPassCount) 86 : Count(Count), BeforeInitialization(BeforeInitialization), 87 BeforeFinalization(BeforeFinalization), 88 MachineFunctionPassCount(MachineFunctionPassCount) {} 89 90 Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) { 91 // Force doInitialization fail by starting with big `Count`. 92 if (Count > 10000) 93 return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode()); 94 95 // + 1 96 ++Count; 97 BeforeInitialization.push_back(Count); 98 return Error::success(); 99 } 100 Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) { 101 // Force doFinalization fail by starting with big `Count`. 102 if (Count > 1000) 103 return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode()); 104 105 // + 1 106 ++Count; 107 BeforeFinalization.push_back(Count); 108 return Error::success(); 109 } 110 111 PreservedAnalyses run(MachineFunction &MF, 112 MachineFunctionAnalysisManager &MFAM) { 113 // Query function analysis result. 114 TestFunctionAnalysis::Result &FAR = 115 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 116 // 3 + 1 + 1 = 5 117 Count += FAR.InstructionCount; 118 119 // Query module analysis result. 120 MachineModuleInfo &MMI = 121 MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent()); 122 // 1 + 1 + 1 = 3 123 Count += (MMI.getModule() == MF.getFunction().getParent()); 124 125 // Query machine function analysis result. 126 TestMachineFunctionAnalysis::Result &MFAR = 127 MFAM.getResult<TestMachineFunctionAnalysis>(MF); 128 // 3 + 1 + 1 = 5 129 Count += MFAR.InstructionCount; 130 131 MachineFunctionPassCount.push_back(Count); 132 133 return PreservedAnalyses::none(); 134 } 135 136 int &Count; 137 std::vector<int> &BeforeInitialization; 138 std::vector<int> &BeforeFinalization; 139 std::vector<int> &MachineFunctionPassCount; 140 }; 141 142 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> { 143 TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount) 144 : Count(Count), MachineModulePassCount(MachineModulePassCount) {} 145 146 Error run(Module &M, MachineFunctionAnalysisManager &MFAM) { 147 MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M); 148 // + 1 149 Count += (MMI.getModule() == &M); 150 MachineModulePassCount.push_back(Count); 151 return Error::success(); 152 } 153 154 PreservedAnalyses run(MachineFunction &MF, 155 MachineFunctionAnalysisManager &AM) { 156 llvm_unreachable( 157 "This should never be reached because this is machine module pass"); 158 } 159 160 int &Count; 161 std::vector<int> &MachineModulePassCount; 162 }; 163 164 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) { 165 SMDiagnostic Err; 166 return parseAssemblyString(IR, Err, Context); 167 } 168 169 class PassManagerTest : public ::testing::Test { 170 protected: 171 LLVMContext Context; 172 std::unique_ptr<Module> M; 173 std::unique_ptr<TargetMachine> TM; 174 175 public: 176 PassManagerTest() 177 : M(parseIR(Context, "define void @f() {\n" 178 "entry:\n" 179 " call void @g()\n" 180 " call void @h()\n" 181 " ret void\n" 182 "}\n" 183 "define void @g() {\n" 184 " ret void\n" 185 "}\n" 186 "define void @h() {\n" 187 " ret void\n" 188 "}\n")) { 189 // MachineModuleAnalysis needs a TargetMachine instance. 190 llvm::InitializeAllTargets(); 191 192 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple()); 193 std::string Error; 194 const Target *TheTarget = 195 TargetRegistry::lookupTarget(TripleName, Error); 196 if (!TheTarget) 197 return; 198 199 TargetOptions Options; 200 TM.reset(TheTarget->createTargetMachine(TripleName, "", "", Options, 201 std::nullopt)); 202 } 203 }; 204 205 TEST_F(PassManagerTest, Basic) { 206 if (!TM) 207 GTEST_SKIP(); 208 209 LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get()); 210 M->setDataLayout(TM->createDataLayout()); 211 212 LoopAnalysisManager LAM; 213 FunctionAnalysisManager FAM; 214 CGSCCAnalysisManager CGAM; 215 ModuleAnalysisManager MAM; 216 PassBuilder PB(TM.get()); 217 PB.registerModuleAnalyses(MAM); 218 PB.registerFunctionAnalyses(FAM); 219 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 220 221 FAM.registerPass([&] { return TestFunctionAnalysis(); }); 222 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 223 MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); }); 224 MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 225 226 MachineFunctionAnalysisManager MFAM; 227 { 228 // Test move assignment. 229 MachineFunctionAnalysisManager NestedMFAM(FAM, MAM); 230 NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 231 NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); }); 232 MFAM = std::move(NestedMFAM); 233 } 234 235 int Count = 0; 236 std::vector<int> BeforeInitialization[2]; 237 std::vector<int> BeforeFinalization[2]; 238 std::vector<int> TestMachineFunctionCount[2]; 239 std::vector<int> TestMachineModuleCount[2]; 240 241 MachineFunctionPassManager MFPM; 242 { 243 // Test move assignment. 244 MachineFunctionPassManager NestedMFPM; 245 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0])); 246 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0], 247 BeforeFinalization[0], 248 TestMachineFunctionCount[0])); 249 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1])); 250 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 251 BeforeFinalization[1], 252 TestMachineFunctionCount[1])); 253 MFPM = std::move(NestedMFPM); 254 } 255 256 ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM))); 257 258 // Check first machine module pass 259 EXPECT_EQ(1u, TestMachineModuleCount[0].size()); 260 EXPECT_EQ(3, TestMachineModuleCount[0][0]); 261 262 // Check first machine function pass 263 EXPECT_EQ(1u, BeforeInitialization[0].size()); 264 EXPECT_EQ(1, BeforeInitialization[0][0]); 265 EXPECT_EQ(3u, TestMachineFunctionCount[0].size()); 266 EXPECT_EQ(10, TestMachineFunctionCount[0][0]); 267 EXPECT_EQ(13, TestMachineFunctionCount[0][1]); 268 EXPECT_EQ(16, TestMachineFunctionCount[0][2]); 269 EXPECT_EQ(1u, BeforeFinalization[0].size()); 270 EXPECT_EQ(31, BeforeFinalization[0][0]); 271 272 // Check second machine module pass 273 EXPECT_EQ(1u, TestMachineModuleCount[1].size()); 274 EXPECT_EQ(17, TestMachineModuleCount[1][0]); 275 276 // Check second machine function pass 277 EXPECT_EQ(1u, BeforeInitialization[1].size()); 278 EXPECT_EQ(2, BeforeInitialization[1][0]); 279 EXPECT_EQ(3u, TestMachineFunctionCount[1].size()); 280 EXPECT_EQ(24, TestMachineFunctionCount[1][0]); 281 EXPECT_EQ(27, TestMachineFunctionCount[1][1]); 282 EXPECT_EQ(30, TestMachineFunctionCount[1][2]); 283 EXPECT_EQ(1u, BeforeFinalization[1].size()); 284 EXPECT_EQ(32, BeforeFinalization[1][0]); 285 286 EXPECT_EQ(32, Count); 287 288 // doInitialization returns error 289 Count = 10000; 290 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 291 BeforeFinalization[1], 292 TestMachineFunctionCount[1])); 293 std::string Message; 294 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { 295 Message = Error.getMessage(); 296 }); 297 EXPECT_EQ(Message, DoInitErrMsg); 298 299 // doFinalization returns error 300 Count = 1000; 301 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 302 BeforeFinalization[1], 303 TestMachineFunctionCount[1])); 304 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { 305 Message = Error.getMessage(); 306 }); 307 EXPECT_EQ(Message, DoFinalErrMsg); 308 } 309 310 } // namespace 311