xref: /llvm-project/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp (revision 6047deb7c2aa94d9bc2b70b49799d22cce778bd4)
1 //===----- ProfDataUtils.cpp - Unit tests for ProfDataUtils ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BasicAliasAnalysis.h"
11 #include "llvm/Analysis/BlockFrequencyInfo.h"
12 #include "llvm/Analysis/BranchProbabilityInfo.h"
13 #include "llvm/Analysis/CFG.h"
14 #include "llvm/Analysis/DomTreeUpdater.h"
15 #include "llvm/Analysis/LoopInfo.h"
16 #include "llvm/Analysis/MemorySSA.h"
17 #include "llvm/Analysis/MemorySSAUpdater.h"
18 #include "llvm/Analysis/PostDominators.h"
19 #include "llvm/Analysis/TargetLibraryInfo.h"
20 #include "llvm/AsmParser/Parser.h"
21 #include "llvm/IR/BasicBlock.h"
22 #include "llvm/IR/Dominators.h"
23 #include "llvm/IR/LLVMContext.h"
24 #include "llvm/IR/ProfDataUtils.h"
25 #include "llvm/Support/SourceMgr.h"
26 #include "llvm/Transforms/Utils/BreakCriticalEdges.h"
27 #include "gtest/gtest.h"
28 
29 using namespace llvm;
30 
31 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
32   SMDiagnostic Err;
33   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
34   if (!Mod)
35     Err.print("ProfDataUtilsTests", errs());
36   return Mod;
37 }
38 
39 TEST(ProfDataUtils, extractWeights) {
40   LLVMContext C;
41   std::unique_ptr<Module> M = parseIR(C, R"IR(
42 define void @foo(i1 %cond0) {
43 entry:
44   br i1 %cond0, label %bb0, label %bb1, !prof !1
45 bb0:
46  %0 = mul i32 1, 2
47  br label %bb1
48 bb1:
49   ret void
50 }
51 
52 !1 = !{!"branch_weights", i32 1, i32 100000}
53 )IR");
54   Function *F = M->getFunction("foo");
55   auto &Entry = F->getEntryBlock();
56   auto &I = Entry.front();
57   auto *Branch = dyn_cast<BranchInst>(&I);
58   EXPECT_NE(nullptr, Branch);
59   auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
60   EXPECT_NE(ProfileData, nullptr);
61   EXPECT_TRUE(hasProfMD(I));
62   SmallVector<uint32_t> Weights;
63   EXPECT_TRUE(extractBranchWeights(ProfileData, Weights));
64   EXPECT_EQ(Weights[0], 1U);
65   EXPECT_EQ(Weights[1], 100000U);
66   EXPECT_EQ(Weights.size(), 2U);
67 }
68 
69 TEST(ProfDataUtils, NoWeights) {
70   LLVMContext C;
71   std::unique_ptr<Module> M = parseIR(C, R"IR(
72 define void @foo(i1 %cond0) {
73 entry:
74   br i1 %cond0, label %bb0, label %bb1
75 bb0:
76  %0 = mul i32 1, 2
77  br label %bb1
78 bb1:
79   ret void
80 }
81 )IR");
82   Function *F = M->getFunction("foo");
83   auto &Entry = F->getEntryBlock();
84   auto &I = Entry.front();
85   auto *Branch = dyn_cast<BranchInst>(&I);
86   EXPECT_NE(nullptr, Branch);
87   auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
88   EXPECT_EQ(ProfileData, nullptr);
89   EXPECT_FALSE(hasProfMD(I));
90   SmallVector<uint32_t> Weights;
91   EXPECT_FALSE(extractBranchWeights(ProfileData, Weights));
92   EXPECT_EQ(Weights.size(), 0U);
93 }
94