xref: /llvm-project/llvm/unittests/Analysis/InlineCostTest.cpp (revision 36c6632eb43bf67e19c8a6a21981cf66e06389b4)
1 //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/InlineCost.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/InlineModelFeatureMaps.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/AsmParser/Parser.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "gtest/gtest.h"
21 
22 namespace {
23 
24 using namespace llvm;
25 
getCallInFunction(Function * F)26 CallBase *getCallInFunction(Function *F) {
27   for (auto &I : instructions(F)) {
28     if (auto *CB = dyn_cast<llvm::CallBase>(&I))
29       return CB;
30   }
31   return nullptr;
32 }
33 
getInliningCostFeaturesForCall(CallBase & CB)34 std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
35   ModuleAnalysisManager MAM;
36   FunctionAnalysisManager FAM;
37   FAM.registerPass([&] { return TargetIRAnalysis(); });
38   FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
39   FAM.registerPass([&] { return AssumptionAnalysis(); });
40   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
41 
42   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43   FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
44 
45   ModulePassManager MPM;
46   MPM.run(*CB.getModule(), MAM);
47 
48   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
49     return FAM.getResult<AssumptionAnalysis>(F);
50   };
51   auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
52 
53   return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
54 }
55 
56 // Tests that we can retrieve the CostFeatures without an error
TEST(InlineCostTest,CostFeatures)57 TEST(InlineCostTest, CostFeatures) {
58   const auto *const IR = R"IR(
59 define i32 @f(i32) {
60   ret i32 4
61 }
62 
63 define i32 @g(i32) {
64   %2 = call i32 @f(i32 0)
65   ret i32 %2
66 }
67 )IR";
68 
69   LLVMContext C;
70   SMDiagnostic Err;
71   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
72   ASSERT_TRUE(M);
73 
74   auto *G = M->getFunction("g");
75   ASSERT_TRUE(G);
76 
77   // find the call to f in g
78   CallBase *CB = getCallInFunction(G);
79   ASSERT_TRUE(CB);
80 
81   const auto Features = getInliningCostFeaturesForCall(*CB);
82 
83   // Check that the optional is not empty
84   ASSERT_TRUE(Features);
85 }
86 
87 // Tests the calculated SROA cost
TEST(InlineCostTest,SROACost)88 TEST(InlineCostTest, SROACost) {
89   using namespace llvm;
90 
91   const auto *const IR = R"IR(
92 define void @f_savings(ptr %var) {
93   %load = load i32, ptr %var
94   %inc = add i32 %load, 1
95   store i32 %inc, ptr %var
96   ret void
97 }
98 
99 define void @g_savings(i32) {
100   %var = alloca i32
101   call void @f_savings(ptr %var)
102   ret void
103 }
104 
105 define void @f_losses(ptr %var) {
106   %load = load i32, ptr %var
107   %inc = add i32 %load, 1
108   store i32 %inc, ptr %var
109   call void @prevent_sroa(ptr %var)
110   ret void
111 }
112 
113 define void @g_losses(i32) {
114   %var = alloca i32
115   call void @f_losses(ptr %var)
116   ret void
117 }
118 
119 declare void @prevent_sroa(ptr)
120 )IR";
121 
122   LLVMContext C;
123   SMDiagnostic Err;
124   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
125   ASSERT_TRUE(M);
126 
127   const int DefaultInstCost = 5;
128   const int DefaultAllocaCost = 0;
129 
130   const char *GName[] = {"g_savings", "g_losses", nullptr};
131   const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
132   const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
133 
134   for (unsigned i = 0; GName[i]; ++i) {
135     auto *G = M->getFunction(GName[i]);
136     ASSERT_TRUE(G);
137 
138     // find the call to f in g
139     CallBase *CB = getCallInFunction(G);
140     ASSERT_TRUE(CB);
141 
142     const auto Features = getInliningCostFeaturesForCall(*CB);
143     ASSERT_TRUE(Features);
144 
145     // Check the predicted SROA cost
146     auto GetFeature = [&](InlineCostFeatureIndex I) {
147       return (*Features)[static_cast<size_t>(I)];
148     };
149     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
150     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
151   }
152 }
153 
154 } // namespace
155