1 //===- llvm/unittest/CodeGen/PassManager.cpp - PassManager tests ----------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/ADT/Triple.h" 10 #include "llvm/Analysis/CGSCCPassManager.h" 11 #include "llvm/Analysis/LoopAnalysisManager.h" 12 #include "llvm/AsmParser/Parser.h" 13 #include "llvm/CodeGen/MachineModuleInfo.h" 14 #include "llvm/CodeGen/MachinePassManager.h" 15 #include "llvm/IR/LLVMContext.h" 16 #include "llvm/IR/Module.h" 17 #include "llvm/Passes/PassBuilder.h" 18 #include "llvm/Support/Host.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "llvm/Support/TargetRegistry.h" 21 #include "llvm/Support/TargetSelect.h" 22 #include "llvm/Target/TargetMachine.h" 23 #include "gtest/gtest.h" 24 25 using namespace llvm; 26 27 namespace { 28 29 class TestFunctionAnalysis : public AnalysisInfoMixin<TestFunctionAnalysis> { 30 public: 31 struct Result { 32 Result(int Count) : InstructionCount(Count) {} 33 int InstructionCount; 34 }; 35 36 /// Run the analysis pass over the function and return a result. 37 Result run(Function &F, FunctionAnalysisManager &AM) { 38 int Count = 0; 39 for (Function::iterator BBI = F.begin(), BBE = F.end(); BBI != BBE; ++BBI) 40 for (BasicBlock::iterator II = BBI->begin(), IE = BBI->end(); II != IE; 41 ++II) 42 ++Count; 43 return Result(Count); 44 } 45 46 private: 47 friend AnalysisInfoMixin<TestFunctionAnalysis>; 48 static AnalysisKey Key; 49 }; 50 51 AnalysisKey TestFunctionAnalysis::Key; 52 53 class TestMachineFunctionAnalysis 54 : public AnalysisInfoMixin<TestMachineFunctionAnalysis> { 55 public: 56 struct Result { 57 Result(int Count) : InstructionCount(Count) {} 58 int InstructionCount; 59 }; 60 61 /// Run the analysis pass over the machine function and return a result. 62 Result run(MachineFunction &MF, MachineFunctionAnalysisManager::Base &AM) { 63 auto &MFAM = static_cast<MachineFunctionAnalysisManager &>(AM); 64 // Query function analysis result. 65 TestFunctionAnalysis::Result &FAR = 66 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 67 // + 5 68 return FAR.InstructionCount; 69 } 70 71 private: 72 friend AnalysisInfoMixin<TestMachineFunctionAnalysis>; 73 static AnalysisKey Key; 74 }; 75 76 AnalysisKey TestMachineFunctionAnalysis::Key; 77 78 const std::string DoInitErrMsg = "doInitialization failed"; 79 const std::string DoFinalErrMsg = "doFinalization failed"; 80 81 struct TestMachineFunctionPass : public PassInfoMixin<TestMachineFunctionPass> { 82 TestMachineFunctionPass(int &Count, std::vector<int> &BeforeInitialization, 83 std::vector<int> &BeforeFinalization, 84 std::vector<int> &MachineFunctionPassCount) 85 : Count(Count), BeforeInitialization(BeforeInitialization), 86 BeforeFinalization(BeforeFinalization), 87 MachineFunctionPassCount(MachineFunctionPassCount) {} 88 89 Error doInitialization(Module &M, MachineFunctionAnalysisManager &MFAM) { 90 // Force doInitialization fail by starting with big `Count`. 91 if (Count > 10000) 92 return make_error<StringError>(DoInitErrMsg, inconvertibleErrorCode()); 93 94 // + 1 95 ++Count; 96 BeforeInitialization.push_back(Count); 97 return Error::success(); 98 } 99 Error doFinalization(Module &M, MachineFunctionAnalysisManager &MFAM) { 100 // Force doFinalization fail by starting with big `Count`. 101 if (Count > 1000) 102 return make_error<StringError>(DoFinalErrMsg, inconvertibleErrorCode()); 103 104 // + 1 105 ++Count; 106 BeforeFinalization.push_back(Count); 107 return Error::success(); 108 } 109 110 PreservedAnalyses run(MachineFunction &MF, 111 MachineFunctionAnalysisManager &MFAM) { 112 // Query function analysis result. 113 TestFunctionAnalysis::Result &FAR = 114 MFAM.getResult<TestFunctionAnalysis>(MF.getFunction()); 115 // 3 + 1 + 1 = 5 116 Count += FAR.InstructionCount; 117 118 // Query module analysis result. 119 MachineModuleInfo &MMI = 120 MFAM.getResult<MachineModuleAnalysis>(*MF.getFunction().getParent()); 121 // 1 + 1 + 1 = 3 122 Count += (MMI.getModule() == MF.getFunction().getParent()); 123 124 // Query machine function analysis result. 125 TestMachineFunctionAnalysis::Result &MFAR = 126 MFAM.getResult<TestMachineFunctionAnalysis>(MF); 127 // 3 + 1 + 1 = 5 128 Count += MFAR.InstructionCount; 129 130 MachineFunctionPassCount.push_back(Count); 131 132 return PreservedAnalyses::none(); 133 } 134 135 int &Count; 136 std::vector<int> &BeforeInitialization; 137 std::vector<int> &BeforeFinalization; 138 std::vector<int> &MachineFunctionPassCount; 139 }; 140 141 struct TestMachineModulePass : public PassInfoMixin<TestMachineModulePass> { 142 TestMachineModulePass(int &Count, std::vector<int> &MachineModulePassCount) 143 : Count(Count), MachineModulePassCount(MachineModulePassCount) {} 144 145 Error run(Module &M, MachineFunctionAnalysisManager &MFAM) { 146 MachineModuleInfo &MMI = MFAM.getResult<MachineModuleAnalysis>(M); 147 // + 1 148 Count += (MMI.getModule() == &M); 149 MachineModulePassCount.push_back(Count); 150 return Error::success(); 151 } 152 153 PreservedAnalyses run(MachineFunction &MF, 154 MachineFunctionAnalysisManager &AM) { 155 llvm_unreachable( 156 "This should never be reached because this is machine module pass"); 157 } 158 159 int &Count; 160 std::vector<int> &MachineModulePassCount; 161 }; 162 163 std::unique_ptr<Module> parseIR(LLVMContext &Context, const char *IR) { 164 SMDiagnostic Err; 165 return parseAssemblyString(IR, Err, Context); 166 } 167 168 class PassManagerTest : public ::testing::Test { 169 protected: 170 LLVMContext Context; 171 std::unique_ptr<Module> M; 172 std::unique_ptr<TargetMachine> TM; 173 174 public: 175 PassManagerTest() 176 : M(parseIR(Context, "define void @f() {\n" 177 "entry:\n" 178 " call void @g()\n" 179 " call void @h()\n" 180 " ret void\n" 181 "}\n" 182 "define void @g() {\n" 183 " ret void\n" 184 "}\n" 185 "define void @h() {\n" 186 " ret void\n" 187 "}\n")) { 188 // MachineModuleAnalysis needs a TargetMachine instance. 189 llvm::InitializeAllTargets(); 190 191 std::string TripleName = Triple::normalize(sys::getDefaultTargetTriple()); 192 std::string Error; 193 const Target *TheTarget = 194 TargetRegistry::lookupTarget(TripleName, Error); 195 if (!TheTarget) 196 return; 197 198 TargetOptions Options; 199 TM.reset(TheTarget->createTargetMachine(TripleName, "", "", 200 Options, None)); 201 } 202 }; 203 204 TEST_F(PassManagerTest, Basic) { 205 if (!TM) 206 GTEST_SKIP(); 207 208 LLVMTargetMachine *LLVMTM = static_cast<LLVMTargetMachine *>(TM.get()); 209 M->setDataLayout(TM->createDataLayout()); 210 211 LoopAnalysisManager LAM; 212 FunctionAnalysisManager FAM; 213 CGSCCAnalysisManager CGAM; 214 ModuleAnalysisManager MAM; 215 PassBuilder PB(TM.get()); 216 PB.registerModuleAnalyses(MAM); 217 PB.registerFunctionAnalyses(FAM); 218 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 219 220 FAM.registerPass([&] { return TestFunctionAnalysis(); }); 221 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 222 MAM.registerPass([&] { return MachineModuleAnalysis(LLVMTM); }); 223 MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 224 225 MachineFunctionAnalysisManager MFAM; 226 { 227 // Test move assignment. 228 MachineFunctionAnalysisManager NestedMFAM(FAM, MAM); 229 NestedMFAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 230 NestedMFAM.registerPass([&] { return TestMachineFunctionAnalysis(); }); 231 MFAM = std::move(NestedMFAM); 232 } 233 234 int Count = 0; 235 std::vector<int> BeforeInitialization[2]; 236 std::vector<int> BeforeFinalization[2]; 237 std::vector<int> TestMachineFunctionCount[2]; 238 std::vector<int> TestMachineModuleCount[2]; 239 240 MachineFunctionPassManager MFPM; 241 { 242 // Test move assignment. 243 MachineFunctionPassManager NestedMFPM; 244 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[0])); 245 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[0], 246 BeforeFinalization[0], 247 TestMachineFunctionCount[0])); 248 NestedMFPM.addPass(TestMachineModulePass(Count, TestMachineModuleCount[1])); 249 NestedMFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 250 BeforeFinalization[1], 251 TestMachineFunctionCount[1])); 252 MFPM = std::move(NestedMFPM); 253 } 254 255 ASSERT_FALSE(errorToBool(MFPM.run(*M, MFAM))); 256 257 // Check first machine module pass 258 EXPECT_EQ(1u, TestMachineModuleCount[0].size()); 259 EXPECT_EQ(3, TestMachineModuleCount[0][0]); 260 261 // Check first machine function pass 262 EXPECT_EQ(1u, BeforeInitialization[0].size()); 263 EXPECT_EQ(1, BeforeInitialization[0][0]); 264 EXPECT_EQ(3u, TestMachineFunctionCount[0].size()); 265 EXPECT_EQ(10, TestMachineFunctionCount[0][0]); 266 EXPECT_EQ(13, TestMachineFunctionCount[0][1]); 267 EXPECT_EQ(16, TestMachineFunctionCount[0][2]); 268 EXPECT_EQ(1u, BeforeFinalization[0].size()); 269 EXPECT_EQ(31, BeforeFinalization[0][0]); 270 271 // Check second machine module pass 272 EXPECT_EQ(1u, TestMachineModuleCount[1].size()); 273 EXPECT_EQ(17, TestMachineModuleCount[1][0]); 274 275 // Check second machine function pass 276 EXPECT_EQ(1u, BeforeInitialization[1].size()); 277 EXPECT_EQ(2, BeforeInitialization[1][0]); 278 EXPECT_EQ(3u, TestMachineFunctionCount[1].size()); 279 EXPECT_EQ(24, TestMachineFunctionCount[1][0]); 280 EXPECT_EQ(27, TestMachineFunctionCount[1][1]); 281 EXPECT_EQ(30, TestMachineFunctionCount[1][2]); 282 EXPECT_EQ(1u, BeforeFinalization[1].size()); 283 EXPECT_EQ(32, BeforeFinalization[1][0]); 284 285 EXPECT_EQ(32, Count); 286 287 // doInitialization returns error 288 Count = 10000; 289 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 290 BeforeFinalization[1], 291 TestMachineFunctionCount[1])); 292 std::string Message; 293 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { 294 Message = Error.getMessage(); 295 }); 296 EXPECT_EQ(Message, DoInitErrMsg); 297 298 // doFinalization returns error 299 Count = 1000; 300 MFPM.addPass(TestMachineFunctionPass(Count, BeforeInitialization[1], 301 BeforeFinalization[1], 302 TestMachineFunctionCount[1])); 303 llvm::handleAllErrors(MFPM.run(*M, MFAM), [&](llvm::StringError &Error) { 304 Message = Error.getMessage(); 305 }); 306 EXPECT_EQ(Message, DoFinalErrMsg); 307 } 308 309 } // namespace 310