xref: /llvm-project/llvm/unittests/Analysis/CtxProfAnalysisTest.cpp (revision 82266d3a2b33f49165c0f24d3db5ea9875cc706c)
1 //===--- CtxProfAnalysisTest.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/CtxProfAnalysis.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/CGSCCPassManager.h"
13 #include "llvm/Analysis/LoopAnalysisManager.h"
14 #include "llvm/AsmParser/Parser.h"
15 #include "llvm/IR/Analysis.h"
16 #include "llvm/IR/Module.h"
17 #include "llvm/IR/PassInstrumentation.h"
18 #include "llvm/IR/PassManager.h"
19 #include "llvm/Passes/PassBuilder.h"
20 #include "llvm/Support/SourceMgr.h"
21 #include "llvm/Transforms/Instrumentation/PGOInstrumentation.h"
22 #include "gmock/gmock.h"
23 #include "gtest/gtest.h"
24 
25 using namespace llvm;
26 
27 namespace {
28 
29 class CtxProfAnalysisTest : public testing::Test {
30   static constexpr auto *IR = R"IR(
31 declare void @bar()
32 
33 define private void @foo(i32 %a, ptr %fct) #0 !guid !0 {
34   %t = icmp eq i32 %a, 0
35   br i1 %t, label %yes, label %no
36 yes:
37   call void %fct(i32 %a)
38   br label %exit
39 no:
40   call void @bar()
41   br label %exit
42 exit:
43   ret void
44 }
45 
46 define void @an_entrypoint(i32 %a) {
47   %t = icmp eq i32 %a, 0
48   br i1 %t, label %yes, label %no
49 
50 yes:
51   call void @foo(i32 1, ptr null)
52   ret void
53 no:
54   ret void
55 }
56 
57 define void @another_entrypoint_no_callees(i32 %a) {
58   %t = icmp eq i32 %a, 0
59   br i1 %t, label %yes, label %no
60 
61 yes:
62   ret void
63 no:
64   ret void
65 }
66 
67 define void @inlineasm() {
68   call void asm "nop", ""()
69   ret void
70 }
71 
72 attributes #0 = { noinline }
73 !0 = !{ i64 11872291593386833696 }
74 )IR";
75 
76 protected:
77   LLVMContext C;
78   PassBuilder PB;
79   ModuleAnalysisManager MAM;
80   FunctionAnalysisManager FAM;
81   CGSCCAnalysisManager CGAM;
82   LoopAnalysisManager LAM;
83   std::unique_ptr<Module> M;
84 
85   void SetUp() override {
86     SMDiagnostic Err;
87     M = parseAssemblyString(IR, Err, C);
88     ASSERT_TRUE(!!M);
89   }
90 
91 public:
92   CtxProfAnalysisTest() {
93     PB.registerModuleAnalyses(MAM);
94     PB.registerCGSCCAnalyses(CGAM);
95     PB.registerFunctionAnalyses(FAM);
96     PB.registerLoopAnalyses(LAM);
97     PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
98   }
99 };
100 
101 TEST_F(CtxProfAnalysisTest, GetCallsiteIDTest) {
102   ModulePassManager MPM;
103   MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
104   EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
105   auto *F = M->getFunction("foo");
106   ASSERT_NE(F, nullptr);
107   std::vector<uint32_t> InsValues;
108 
109   for (auto &BB : *F)
110     for (auto &I : BB)
111       if (auto *CB = dyn_cast<CallBase>(&I)) {
112         // Skip instrumentation inserted intrinsics.
113         if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())
114           continue;
115         auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB);
116         ASSERT_NE(Ins, nullptr);
117         InsValues.push_back(Ins->getIndex()->getZExtValue());
118       }
119 
120   EXPECT_THAT(InsValues, testing::ElementsAre(0, 1));
121 }
122 
123 TEST_F(CtxProfAnalysisTest, GetCallsiteIDInlineAsmTest) {
124   ModulePassManager MPM;
125   MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
126   EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
127   auto *F = M->getFunction("inlineasm");
128   ASSERT_NE(F, nullptr);
129   std::vector<const Instruction *> InsValues;
130 
131   for (auto &BB : *F)
132     for (auto &I : BB)
133       if (auto *CB = dyn_cast<CallBase>(&I)) {
134         // Skip instrumentation inserted intrinsics.
135         if (CB->getCalledFunction() && CB->getCalledFunction()->isIntrinsic())
136           continue;
137         auto *Ins = CtxProfAnalysis::getCallsiteInstrumentation(*CB);
138         InsValues.push_back(Ins);
139       }
140 
141   EXPECT_THAT(InsValues, testing::ElementsAre(nullptr));
142 }
143 
144 TEST_F(CtxProfAnalysisTest, GetCallsiteIDNegativeTest) {
145   auto *F = M->getFunction("foo");
146   ASSERT_NE(F, nullptr);
147   CallBase *FirstCall = nullptr;
148   for (auto &BB : *F)
149     for (auto &I : BB)
150       if (auto *CB = dyn_cast<CallBase>(&I)) {
151         if (CB->isIndirectCall() || !CB->getCalledFunction()->isIntrinsic()) {
152           FirstCall = CB;
153           break;
154         }
155       }
156   ASSERT_NE(FirstCall, nullptr);
157   auto *IndIns = CtxProfAnalysis::getCallsiteInstrumentation(*FirstCall);
158   EXPECT_EQ(IndIns, nullptr);
159 }
160 
161 TEST_F(CtxProfAnalysisTest, GetBBIDTest) {
162   ModulePassManager MPM;
163   MPM.addPass(PGOInstrumentationGen(PGOInstrumentationType::CTXPROF));
164   EXPECT_FALSE(MPM.run(*M, MAM).areAllPreserved());
165   auto *F = M->getFunction("foo");
166   ASSERT_NE(F, nullptr);
167   std::map<std::string, int> BBNameAndID;
168 
169   for (auto &BB : *F) {
170     auto *Ins = CtxProfAnalysis::getBBInstrumentation(BB);
171     if (Ins)
172       BBNameAndID[BB.getName().str()] =
173           static_cast<int>(Ins->getIndex()->getZExtValue());
174     else
175       BBNameAndID[BB.getName().str()] = -1;
176   }
177 
178   EXPECT_THAT(BBNameAndID,
179               testing::UnorderedElementsAre(
180                   testing::Pair("", 0), testing::Pair("yes", 1),
181                   testing::Pair("no", -1), testing::Pair("exit", -1)));
182 }
183 } // namespace
184