1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/IR/Module.h" 12 #include "llvm/Passes/PassBuilder.h" 13 #include "llvm/ProfileData/InstrProf.h" 14 15 #include "gmock/gmock.h" 16 #include "gtest/gtest.h" 17 18 #include <tuple> 19 20 namespace { 21 22 using namespace llvm; 23 24 using testing::_; 25 using ::testing::DoDefault; 26 using ::testing::Invoke; 27 using ::testing::NotNull; 28 using ::testing::Ref; 29 using ::testing::Return; 30 using ::testing::Sequence; 31 using ::testing::Test; 32 using ::testing::TestParamInfo; 33 using ::testing::Values; 34 using ::testing::WithParamInterface; 35 36 template <typename Derived> class MockAnalysisHandleBase { 37 public: 38 class Analysis : public AnalysisInfoMixin<Analysis> { 39 public: 40 class Result { 41 public: 42 // Forward invalidation events to the mock handle. 43 bool invalidate(Module &M, const PreservedAnalyses &PA, 44 ModuleAnalysisManager::Invalidator &Inv) { 45 return Handle->invalidate(M, PA, Inv); 46 } 47 48 private: 49 explicit Result(Derived *Handle) : Handle(Handle) {} 50 51 friend MockAnalysisHandleBase; 52 Derived *Handle; 53 }; 54 55 Result run(Module &M, ModuleAnalysisManager &AM) { 56 return Handle->run(M, AM); 57 } 58 59 private: 60 friend AnalysisInfoMixin<Analysis>; 61 friend MockAnalysisHandleBase; 62 static inline AnalysisKey Key; 63 64 Derived *Handle; 65 66 explicit Analysis(Derived *Handle) : Handle(Handle) {} 67 }; 68 69 Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); } 70 71 typename Analysis::Result getResult() { 72 return typename Analysis::Result(static_cast<Derived *>(this)); 73 } 74 75 protected: 76 void setDefaults() { 77 ON_CALL(static_cast<Derived &>(*this), run(_, _)) 78 .WillByDefault(Return(this->getResult())); 79 ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _)) 80 .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA, 81 ModuleAnalysisManager::Invalidator &) { 82 auto PAC = PA.template getChecker<Analysis>(); 83 return !PAC.preserved() && 84 !PAC.template preservedSet<AllAnalysesOn<Module>>(); 85 })); 86 } 87 88 private: 89 friend Derived; 90 MockAnalysisHandleBase() = default; 91 }; 92 93 class MockModuleAnalysisHandle 94 : public MockAnalysisHandleBase<MockModuleAnalysisHandle> { 95 public: 96 MockModuleAnalysisHandle() { setDefaults(); } 97 98 MOCK_METHOD(typename Analysis::Result, run, 99 (Module &, ModuleAnalysisManager &)); 100 101 MOCK_METHOD(bool, invalidate, 102 (Module &, const PreservedAnalyses &, 103 ModuleAnalysisManager::Invalidator &)); 104 }; 105 106 struct PGOInstrumentationGenTest 107 : public Test, 108 WithParamInterface<std::tuple<StringRef, StringRef>> { 109 ModulePassManager MPM; 110 PassBuilder PB; 111 MockModuleAnalysisHandle MMAHandle; 112 LoopAnalysisManager LAM; 113 FunctionAnalysisManager FAM; 114 CGSCCAnalysisManager CGAM; 115 ModuleAnalysisManager MAM; 116 LLVMContext Context; 117 std::unique_ptr<Module> M; 118 119 PGOInstrumentationGenTest() { 120 MAM.registerPass([&] { return MMAHandle.getAnalysis(); }); 121 PB.registerModuleAnalyses(MAM); 122 PB.registerCGSCCAnalyses(CGAM); 123 PB.registerFunctionAnalyses(FAM); 124 PB.registerLoopAnalyses(LAM); 125 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 126 MPM.addPass( 127 RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>()); 128 MPM.addPass(PGOInstrumentationGen()); 129 } 130 131 void parseAssembly(const StringRef IR) { 132 SMDiagnostic Error; 133 M = parseAssemblyString(IR, Error, Context); 134 std::string ErrMsg; 135 raw_string_ostream OS(ErrMsg); 136 Error.print("", OS); 137 138 // A failure here means that the test itself is buggy. 139 if (!M) 140 report_fatal_error(ErrMsg.c_str()); 141 } 142 }; 143 144 static constexpr StringRef CodeWithFuncDefs = R"( 145 define i32 @f(i32 %n) { 146 entry: 147 ret i32 0 148 })"; 149 150 static constexpr StringRef CodeWithFuncDecls = R"( 151 declare i32 @f(i32); 152 )"; 153 154 static constexpr StringRef CodeWithGlobals = R"( 155 @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f] 156 declare i32 @f(i32); 157 )"; 158 159 INSTANTIATE_TEST_SUITE_P( 160 PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest, 161 Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"), 162 std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"), 163 std::make_tuple(CodeWithGlobals, "instrument_globals")), 164 [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) { 165 return std::get<1>(Info.param).str(); 166 }); 167 168 TEST_P(PGOInstrumentationGenTest, Instrumented) { 169 const StringRef Code = std::get<0>(GetParam()); 170 parseAssembly(Code); 171 172 ASSERT_THAT(M, NotNull()); 173 174 Sequence PassSequence; 175 EXPECT_CALL(MMAHandle, run(Ref(*M), _)) 176 .InSequence(PassSequence) 177 .WillOnce(DoDefault()); 178 EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _)) 179 .InSequence(PassSequence) 180 .WillOnce(DoDefault()); 181 182 MPM.run(*M, MAM); 183 184 const auto *IRInstrVar = 185 M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR)); 186 EXPECT_THAT(IRInstrVar, NotNull()); 187 EXPECT_FALSE(IRInstrVar->isDeclaration()); 188 } 189 190 } // end anonymous namespace 191