xref: /llvm-project/llvm/unittests/Transforms/Utils/ProfDataUtilTest.cpp (revision 4d5906e0bf27b7aea451cf60e0d38a01f9b96085)
//===----- ProfDataUtilTest.cpp - Unit tests for ProfDataUtils ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/AsmParser/Parser.h"
10 #include "llvm/IR/BasicBlock.h"
11 #include "llvm/IR/Function.h"
12 #include "llvm/IR/Instructions.h"
13 #include "llvm/IR/LLVMContext.h"
14 #include "llvm/IR/Module.h"
15 #include "llvm/IR/ProfDataUtils.h"
16 #include "llvm/Support/SourceMgr.h"
17 #include "gtest/gtest.h"
18 
19 using namespace llvm;
20 
parseIR(LLVMContext & C,const char * IR)21 static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
22   SMDiagnostic Err;
23   std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
24   if (!Mod)
25     Err.print("ProfDataUtilsTests", errs());
26   return Mod;
27 }
28 
// A conditional branch carrying !prof branch_weights metadata must be
// reported by hasProfMD, and extractBranchWeights must return exactly the
// weights listed in the metadata node, in order.
TEST(ProfDataUtils, extractWeights) {
  LLVMContext C;
  std::unique_ptr<Module> M = parseIR(C, R"IR(
define void @foo(i1 %cond0) {
entry:
  br i1 %cond0, label %bb0, label %bb1, !prof !1
bb0:
 %0 = mul i32 1, 2
 br label %bb1
bb1:
  ret void
}

!1 = !{!"branch_weights", i32 1, i32 100000}
)IR");
  // Use fatal assertions for preconditions: every value checked here is
  // dereferenced afterwards, so a non-fatal EXPECT would crash the binary.
  ASSERT_NE(M, nullptr);
  Function *F = M->getFunction("foo");
  ASSERT_NE(F, nullptr);
  auto &I = F->getEntryBlock().front();
  auto *Branch = dyn_cast<BranchInst>(&I);
  ASSERT_NE(Branch, nullptr);
  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
  ASSERT_NE(ProfileData, nullptr);
  EXPECT_TRUE(hasProfMD(I));

  SmallVector<uint32_t> Weights;
  EXPECT_TRUE(extractBranchWeights(ProfileData, Weights));
  // Verify the size before indexing so a short result fails cleanly instead
  // of reading out of bounds.
  ASSERT_EQ(Weights.size(), 2U);
  EXPECT_EQ(Weights[0], 1U);
  EXPECT_EQ(Weights[1], 100000U);
}
58 
// A conditional branch without !prof metadata must not be reported by
// hasProfMD, and extractBranchWeights on the (null) metadata must fail and
// leave the output vector empty.
TEST(ProfDataUtils, NoWeights) {
  LLVMContext C;
  std::unique_ptr<Module> M = parseIR(C, R"IR(
define void @foo(i1 %cond0) {
entry:
  br i1 %cond0, label %bb0, label %bb1
bb0:
 %0 = mul i32 1, 2
 br label %bb1
bb1:
  ret void
}
)IR");
  // Fatal checks: M, F and Branch are all dereferenced below.
  ASSERT_NE(M, nullptr);
  Function *F = M->getFunction("foo");
  ASSERT_NE(F, nullptr);
  auto &I = F->getEntryBlock().front();
  auto *Branch = dyn_cast<BranchInst>(&I);
  ASSERT_NE(Branch, nullptr);
  auto *ProfileData = Branch->getMetadata(LLVMContext::MD_prof);
  EXPECT_EQ(ProfileData, nullptr);
  EXPECT_FALSE(hasProfMD(I));

  // Passing the null metadata through is intentional: extraction must
  // report failure rather than produce weights.
  SmallVector<uint32_t> Weights;
  EXPECT_FALSE(extractBranchWeights(ProfileData, Weights));
  EXPECT_EQ(Weights.size(), 0U);
}
84