xref: /llvm-project/llvm/unittests/Analysis/InlineCostTest.cpp (revision cc8a346e3fa362c2b1319cca9883182fbf36b6db)
1 //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/InlineCost.h"
10 #include "llvm/Analysis/AssumptionCache.h"
11 #include "llvm/Analysis/InlineModelFeatureMaps.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/AsmParser/Parser.h"
14 #include "llvm/IR/InstIterator.h"
15 #include "llvm/IR/Instructions.h"
16 #include "llvm/IR/LLVMContext.h"
17 #include "llvm/IR/Module.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "gtest/gtest.h"
20 
21 namespace {
22 
23 using namespace llvm;
24 
25 CallBase *getCallInFunction(Function *F) {
26   for (auto &I : instructions(F)) {
27     if (auto *CB = dyn_cast<llvm::CallBase>(&I))
28       return CB;
29   }
30   return nullptr;
31 }
32 
33 std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
34   ModuleAnalysisManager MAM;
35   FunctionAnalysisManager FAM;
36   FAM.registerPass([&] { return TargetIRAnalysis(); });
37   FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
38   FAM.registerPass([&] { return AssumptionAnalysis(); });
39   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
40 
41   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
42   FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43 
44   ModulePassManager MPM;
45   MPM.run(*CB.getModule(), MAM);
46 
47   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
48     return FAM.getResult<AssumptionAnalysis>(F);
49   };
50   auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
51 
52   return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
53 }
54 
55 // Tests that we can retrieve the CostFeatures without an error
56 TEST(InlineCostTest, CostFeatures) {
57   const auto *const IR = R"IR(
58 define i32 @f(i32) {
59   ret i32 4
60 }
61 
62 define i32 @g(i32) {
63   %2 = call i32 @f(i32 0)
64   ret i32 %2
65 }
66 )IR";
67 
68   LLVMContext C;
69   SMDiagnostic Err;
70   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
71   ASSERT_TRUE(M);
72 
73   auto *G = M->getFunction("g");
74   ASSERT_TRUE(G);
75 
76   // find the call to f in g
77   CallBase *CB = getCallInFunction(G);
78   ASSERT_TRUE(CB);
79 
80   const auto Features = getInliningCostFeaturesForCall(*CB);
81 
82   // Check that the optional is not empty
83   ASSERT_TRUE(Features);
84 }
85 
86 // Tests the calculated SROA cost
87 TEST(InlineCostTest, SROACost) {
88   using namespace llvm;
89 
90   const auto *const IR = R"IR(
91 define void @f_savings(ptr %var) {
92   %load = load i32, ptr %var
93   %inc = add i32 %load, 1
94   store i32 %inc, ptr %var
95   ret void
96 }
97 
98 define void @g_savings(i32) {
99   %var = alloca i32
100   call void @f_savings(ptr %var)
101   ret void
102 }
103 
104 define void @f_losses(ptr %var) {
105   %load = load i32, ptr %var
106   %inc = add i32 %load, 1
107   store i32 %inc, ptr %var
108   call void @prevent_sroa(ptr %var)
109   ret void
110 }
111 
112 define void @g_losses(i32) {
113   %var = alloca i32
114   call void @f_losses(ptr %var)
115   ret void
116 }
117 
118 declare void @prevent_sroa(ptr)
119 )IR";
120 
121   LLVMContext C;
122   SMDiagnostic Err;
123   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
124   ASSERT_TRUE(M);
125 
126   const int DefaultInstCost = 5;
127   const int DefaultAllocaCost = 0;
128 
129   const char *GName[] = {"g_savings", "g_losses", nullptr};
130   const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
131   const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
132 
133   for (unsigned i = 0; GName[i]; ++i) {
134     auto *G = M->getFunction(GName[i]);
135     ASSERT_TRUE(G);
136 
137     // find the call to f in g
138     CallBase *CB = getCallInFunction(G);
139     ASSERT_TRUE(CB);
140 
141     const auto Features = getInliningCostFeaturesForCall(*CB);
142     ASSERT_TRUE(Features);
143 
144     // Check the predicted SROA cost
145     auto GetFeature = [&](InlineCostFeatureIndex I) {
146       return (*Features)[static_cast<size_t>(I)];
147     };
148     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
149     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
150   }
151 }
152 
153 } // namespace
154