xref: /llvm-project/llvm/unittests/Analysis/InlineCostTest.cpp (revision 36c6632eb43bf67e19c8a6a21981cf66e06389b4)
199f00635SJacob Hegna //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
299f00635SJacob Hegna //
399f00635SJacob Hegna // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499f00635SJacob Hegna // See https://llvm.org/LICENSE.txt for license information.
599f00635SJacob Hegna // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699f00635SJacob Hegna //
799f00635SJacob Hegna //===----------------------------------------------------------------------===//
899f00635SJacob Hegna 
999f00635SJacob Hegna #include "llvm/Analysis/InlineCost.h"
1071c3a551Sserge-sans-paille #include "llvm/Analysis/AssumptionCache.h"
11cc8a346eSJuan Manuel MARTINEZ CAAMAÑO #include "llvm/Analysis/InlineModelFeatureMaps.h"
1299f00635SJacob Hegna #include "llvm/Analysis/TargetTransformInfo.h"
1399f00635SJacob Hegna #include "llvm/AsmParser/Parser.h"
14cc8a346eSJuan Manuel MARTINEZ CAAMAÑO #include "llvm/IR/InstIterator.h"
1599f00635SJacob Hegna #include "llvm/IR/Instructions.h"
1699f00635SJacob Hegna #include "llvm/IR/LLVMContext.h"
1799f00635SJacob Hegna #include "llvm/IR/Module.h"
18*36c6632eSNikita Popov #include "llvm/IR/PassInstrumentation.h"
1971c3a551Sserge-sans-paille #include "llvm/Support/SourceMgr.h"
2099f00635SJacob Hegna #include "gtest/gtest.h"
2199f00635SJacob Hegna 
2299f00635SJacob Hegna namespace {
2399f00635SJacob Hegna 
2499f00635SJacob Hegna using namespace llvm;
2599f00635SJacob Hegna 
getCallInFunction(Function * F)26cc8a346eSJuan Manuel MARTINEZ CAAMAÑO CallBase *getCallInFunction(Function *F) {
27cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   for (auto &I : instructions(F)) {
28cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     if (auto *CB = dyn_cast<llvm::CallBase>(&I))
29cc8a346eSJuan Manuel MARTINEZ CAAMAÑO       return CB;
30cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   }
31cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   return nullptr;
32cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
33cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
getInliningCostFeaturesForCall(CallBase & CB)34cc8a346eSJuan Manuel MARTINEZ CAAMAÑO std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
35cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ModuleAnalysisManager MAM;
36cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   FunctionAnalysisManager FAM;
37cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   FAM.registerPass([&] { return TargetIRAnalysis(); });
38cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
39cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   FAM.registerPass([&] { return AssumptionAnalysis(); });
40cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
41cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
42cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
44cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
45cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ModulePassManager MPM;
46cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   MPM.run(*CB.getModule(), MAM);
47cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
48cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
49cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     return FAM.getResult<AssumptionAnalysis>(F);
50cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   };
51cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
52cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
53cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
54cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
55cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
56cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // Tests that we can retrieve the CostFeatures without an error
TEST(InlineCostTest,CostFeatures)57cc8a346eSJuan Manuel MARTINEZ CAAMAÑO TEST(InlineCostTest, CostFeatures) {
5899f00635SJacob Hegna   const auto *const IR = R"IR(
5999f00635SJacob Hegna define i32 @f(i32) {
6099f00635SJacob Hegna   ret i32 4
6199f00635SJacob Hegna }
6299f00635SJacob Hegna 
6399f00635SJacob Hegna define i32 @g(i32) {
6499f00635SJacob Hegna   %2 = call i32 @f(i32 0)
6599f00635SJacob Hegna   ret i32 %2
6699f00635SJacob Hegna }
6799f00635SJacob Hegna )IR";
6899f00635SJacob Hegna 
6999f00635SJacob Hegna   LLVMContext C;
7099f00635SJacob Hegna   SMDiagnostic Err;
7199f00635SJacob Hegna   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
7299f00635SJacob Hegna   ASSERT_TRUE(M);
7399f00635SJacob Hegna 
7499f00635SJacob Hegna   auto *G = M->getFunction("g");
7599f00635SJacob Hegna   ASSERT_TRUE(G);
7699f00635SJacob Hegna 
7799f00635SJacob Hegna   // find the call to f in g
78cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   CallBase *CB = getCallInFunction(G);
7999f00635SJacob Hegna   ASSERT_TRUE(CB);
8099f00635SJacob Hegna 
81cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const auto Features = getInliningCostFeaturesForCall(*CB);
8299f00635SJacob Hegna 
8399f00635SJacob Hegna   // Check that the optional is not empty
8499f00635SJacob Hegna   ASSERT_TRUE(Features);
8599f00635SJacob Hegna }
8699f00635SJacob Hegna 
87cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // Tests the calculated SROA cost
TEST(InlineCostTest,SROACost)88cc8a346eSJuan Manuel MARTINEZ CAAMAÑO TEST(InlineCostTest, SROACost) {
89cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   using namespace llvm;
90cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
91cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const auto *const IR = R"IR(
92cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @f_savings(ptr %var) {
93cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %load = load i32, ptr %var
94cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %inc = add i32 %load, 1
95cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   store i32 %inc, ptr %var
96cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ret void
97cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
98cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
99cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @g_savings(i32) {
100cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %var = alloca i32
101cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   call void @f_savings(ptr %var)
102cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ret void
103cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
104cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
105cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @f_losses(ptr %var) {
106cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %load = load i32, ptr %var
107cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %inc = add i32 %load, 1
108cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   store i32 %inc, ptr %var
109cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   call void @prevent_sroa(ptr %var)
110cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ret void
111cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
112cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
113cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @g_losses(i32) {
114cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   %var = alloca i32
115cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   call void @f_losses(ptr %var)
116cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ret void
117cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
118cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
119cc8a346eSJuan Manuel MARTINEZ CAAMAÑO declare void @prevent_sroa(ptr)
120cc8a346eSJuan Manuel MARTINEZ CAAMAÑO )IR";
121cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
122cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   LLVMContext C;
123cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   SMDiagnostic Err;
124cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
125cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   ASSERT_TRUE(M);
126cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
127cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const int DefaultInstCost = 5;
128cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const int DefaultAllocaCost = 0;
129cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
130cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const char *GName[] = {"g_savings", "g_losses", nullptr};
131cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
132cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
133cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
134cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   for (unsigned i = 0; GName[i]; ++i) {
135cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     auto *G = M->getFunction(GName[i]);
136cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     ASSERT_TRUE(G);
137cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
138cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     // find the call to f in g
139cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     CallBase *CB = getCallInFunction(G);
140cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     ASSERT_TRUE(CB);
141cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
142cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     const auto Features = getInliningCostFeaturesForCall(*CB);
143cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     ASSERT_TRUE(Features);
144cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
145cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     // Check the predicted SROA cost
146cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     auto GetFeature = [&](InlineCostFeatureIndex I) {
147cc8a346eSJuan Manuel MARTINEZ CAAMAÑO       return (*Features)[static_cast<size_t>(I)];
148cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     };
149cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
150cc8a346eSJuan Manuel MARTINEZ CAAMAÑO     ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
151cc8a346eSJuan Manuel MARTINEZ CAAMAÑO   }
152cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
153cc8a346eSJuan Manuel MARTINEZ CAAMAÑO 
15499f00635SJacob Hegna } // namespace
155