xref: /llvm-project/llvm/unittests/Transforms/Instrumentation/PGOInstrumentationTest.cpp (revision e36b22f3bf45a23d31b569e53d22b98714cf00e3)
1 //===- PGOInstrumentationTest.cpp - Instrumentation unit tests ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
10 #include "llvm/AsmParser/Parser.h"
11 #include "llvm/IR/Module.h"
12 #include "llvm/Passes/PassBuilder.h"
13 #include "llvm/ProfileData/InstrProf.h"
14 
15 #include "gmock/gmock.h"
16 #include "gtest/gtest.h"
17 
18 #include <tuple>
19 
20 namespace {
21 
22 using namespace llvm;
23 
24 using testing::_;
25 using ::testing::DoDefault;
26 using ::testing::Invoke;
27 using ::testing::NotNull;
28 using ::testing::Ref;
29 using ::testing::Return;
30 using ::testing::Sequence;
31 using ::testing::Test;
32 using ::testing::TestParamInfo;
33 using ::testing::Values;
34 using ::testing::WithParamInterface;
35 
36 template <typename Derived> class MockAnalysisHandleBase {
37 public:
38   class Analysis : public AnalysisInfoMixin<Analysis> {
39   public:
40     class Result {
41     public:
42       // Forward invalidation events to the mock handle.
43       bool invalidate(Module &M, const PreservedAnalyses &PA,
44                       ModuleAnalysisManager::Invalidator &Inv) {
45         return Handle->invalidate(M, PA, Inv);
46       }
47 
48     private:
49       explicit Result(Derived *Handle) : Handle(Handle) {}
50 
51       friend MockAnalysisHandleBase;
52       Derived *Handle;
53     };
54 
55     Result run(Module &M, ModuleAnalysisManager &AM) {
56       return Handle->run(M, AM);
57     }
58 
59   private:
60     friend AnalysisInfoMixin<Analysis>;
61     friend MockAnalysisHandleBase;
62     static inline AnalysisKey Key;
63 
64     Derived *Handle;
65 
66     explicit Analysis(Derived *Handle) : Handle(Handle) {}
67   };
68 
69   Analysis getAnalysis() { return Analysis(static_cast<Derived *>(this)); }
70 
71   typename Analysis::Result getResult() {
72     return typename Analysis::Result(static_cast<Derived *>(this));
73   }
74 
75 protected:
76   void setDefaults() {
77     ON_CALL(static_cast<Derived &>(*this), run(_, _))
78         .WillByDefault(Return(this->getResult()));
79     ON_CALL(static_cast<Derived &>(*this), invalidate(_, _, _))
80         .WillByDefault(Invoke([](Module &M, const PreservedAnalyses &PA,
81                                  ModuleAnalysisManager::Invalidator &) {
82           auto PAC = PA.template getChecker<Analysis>();
83           return !PAC.preserved() &&
84                  !PAC.template preservedSet<AllAnalysesOn<Module>>();
85         }));
86   }
87 
88 private:
89   friend Derived;
90   MockAnalysisHandleBase() = default;
91 };
92 
93 class MockModuleAnalysisHandle
94     : public MockAnalysisHandleBase<MockModuleAnalysisHandle> {
95 public:
96   MockModuleAnalysisHandle() { setDefaults(); }
97 
98   MOCK_METHOD(typename Analysis::Result, run,
99               (Module &, ModuleAnalysisManager &));
100 
101   MOCK_METHOD(bool, invalidate,
102               (Module &, const PreservedAnalyses &,
103                ModuleAnalysisManager::Invalidator &));
104 };
105 
106 struct PGOInstrumentationGenTest
107     : public Test,
108       WithParamInterface<std::tuple<StringRef, StringRef>> {
109   ModulePassManager MPM;
110   PassBuilder PB;
111   MockModuleAnalysisHandle MMAHandle;
112   LoopAnalysisManager LAM;
113   FunctionAnalysisManager FAM;
114   CGSCCAnalysisManager CGAM;
115   ModuleAnalysisManager MAM;
116   LLVMContext Context;
117   std::unique_ptr<Module> M;
118 
119   PGOInstrumentationGenTest() {
120     MAM.registerPass([&] { return MMAHandle.getAnalysis(); });
121     PB.registerModuleAnalyses(MAM);
122     PB.registerCGSCCAnalyses(CGAM);
123     PB.registerFunctionAnalyses(FAM);
124     PB.registerLoopAnalyses(LAM);
125     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
126     MPM.addPass(
127         RequireAnalysisPass<MockModuleAnalysisHandle::Analysis, Module>());
128     MPM.addPass(PGOInstrumentationGen());
129   }
130 
131   void parseAssembly(const StringRef IR) {
132     SMDiagnostic Error;
133     M = parseAssemblyString(IR, Error, Context);
134     std::string ErrMsg;
135     raw_string_ostream OS(ErrMsg);
136     Error.print("", OS);
137 
138     // A failure here means that the test itself is buggy.
139     if (!M)
140       report_fatal_error(ErrMsg.c_str());
141   }
142 };
143 
144 static constexpr StringRef CodeWithFuncDefs = R"(
145   define i32 @f(i32 %n) {
146   entry:
147     ret i32 0
148   })";
149 
150 static constexpr StringRef CodeWithFuncDecls = R"(
151   declare i32 @f(i32);
152 )";
153 
154 static constexpr StringRef CodeWithGlobals = R"(
155   @foo.table = internal unnamed_addr constant [1 x ptr] [ptr @f]
156   declare i32 @f(i32);
157 )";
158 
159 INSTANTIATE_TEST_SUITE_P(
160     PGOInstrumetationGenTestSuite, PGOInstrumentationGenTest,
161     Values(std::make_tuple(CodeWithFuncDefs, "instrument_function_defs"),
162            std::make_tuple(CodeWithFuncDecls, "instrument_function_decls"),
163            std::make_tuple(CodeWithGlobals, "instrument_globals")),
164     [](const TestParamInfo<PGOInstrumentationGenTest::ParamType> &Info) {
165       return std::get<1>(Info.param).str();
166     });
167 
168 TEST_P(PGOInstrumentationGenTest, Instrumented) {
169   const StringRef Code = std::get<0>(GetParam());
170   parseAssembly(Code);
171 
172   ASSERT_THAT(M, NotNull());
173 
174   Sequence PassSequence;
175   EXPECT_CALL(MMAHandle, run(Ref(*M), _))
176       .InSequence(PassSequence)
177       .WillOnce(DoDefault());
178   EXPECT_CALL(MMAHandle, invalidate(Ref(*M), _, _))
179       .InSequence(PassSequence)
180       .WillOnce(DoDefault());
181 
182   MPM.run(*M, MAM);
183 
184   const auto *IRInstrVar =
185       M->getNamedGlobal(INSTR_PROF_QUOTE(INSTR_PROF_RAW_VERSION_VAR));
186   EXPECT_THAT(IRInstrVar, NotNull());
187   EXPECT_FALSE(IRInstrVar->isDeclaration());
188 }
189 
190 } // end anonymous namespace
191