199f00635SJacob Hegna //===- InlineCostTest.cpp - test for InlineCost ---------------------------===//
299f00635SJacob Hegna //
399f00635SJacob Hegna // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
499f00635SJacob Hegna // See https://llvm.org/LICENSE.txt for license information.
599f00635SJacob Hegna // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
699f00635SJacob Hegna //
799f00635SJacob Hegna //===----------------------------------------------------------------------===//
899f00635SJacob Hegna
999f00635SJacob Hegna #include "llvm/Analysis/InlineCost.h"
1071c3a551Sserge-sans-paille #include "llvm/Analysis/AssumptionCache.h"
11cc8a346eSJuan Manuel MARTINEZ CAAMAÑO #include "llvm/Analysis/InlineModelFeatureMaps.h"
1299f00635SJacob Hegna #include "llvm/Analysis/TargetTransformInfo.h"
1399f00635SJacob Hegna #include "llvm/AsmParser/Parser.h"
14cc8a346eSJuan Manuel MARTINEZ CAAMAÑO #include "llvm/IR/InstIterator.h"
1599f00635SJacob Hegna #include "llvm/IR/Instructions.h"
1699f00635SJacob Hegna #include "llvm/IR/LLVMContext.h"
1799f00635SJacob Hegna #include "llvm/IR/Module.h"
18*36c6632eSNikita Popov #include "llvm/IR/PassInstrumentation.h"
1971c3a551Sserge-sans-paille #include "llvm/Support/SourceMgr.h"
2099f00635SJacob Hegna #include "gtest/gtest.h"
2199f00635SJacob Hegna
2299f00635SJacob Hegna namespace {
2399f00635SJacob Hegna
2499f00635SJacob Hegna using namespace llvm;
2599f00635SJacob Hegna
getCallInFunction(Function * F)26cc8a346eSJuan Manuel MARTINEZ CAAMAÑO CallBase *getCallInFunction(Function *F) {
27cc8a346eSJuan Manuel MARTINEZ CAAMAÑO for (auto &I : instructions(F)) {
28cc8a346eSJuan Manuel MARTINEZ CAAMAÑO if (auto *CB = dyn_cast<llvm::CallBase>(&I))
29cc8a346eSJuan Manuel MARTINEZ CAAMAÑO return CB;
30cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
31cc8a346eSJuan Manuel MARTINEZ CAAMAÑO return nullptr;
32cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
33cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
getInliningCostFeaturesForCall(CallBase & CB)34cc8a346eSJuan Manuel MARTINEZ CAAMAÑO std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) {
35cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ModuleAnalysisManager MAM;
36cc8a346eSJuan Manuel MARTINEZ CAAMAÑO FunctionAnalysisManager FAM;
37cc8a346eSJuan Manuel MARTINEZ CAAMAÑO FAM.registerPass([&] { return TargetIRAnalysis(); });
38cc8a346eSJuan Manuel MARTINEZ CAAMAÑO FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); });
39cc8a346eSJuan Manuel MARTINEZ CAAMAÑO FAM.registerPass([&] { return AssumptionAnalysis(); });
40cc8a346eSJuan Manuel MARTINEZ CAAMAÑO MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); });
41cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
42cc8a346eSJuan Manuel MARTINEZ CAAMAÑO MAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43cc8a346eSJuan Manuel MARTINEZ CAAMAÑO FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
44cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
45cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ModulePassManager MPM;
46cc8a346eSJuan Manuel MARTINEZ CAAMAÑO MPM.run(*CB.getModule(), MAM);
47cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
48cc8a346eSJuan Manuel MARTINEZ CAAMAÑO auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & {
49cc8a346eSJuan Manuel MARTINEZ CAAMAÑO return FAM.getResult<AssumptionAnalysis>(F);
50cc8a346eSJuan Manuel MARTINEZ CAAMAÑO };
51cc8a346eSJuan Manuel MARTINEZ CAAMAÑO auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction());
52cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
53cc8a346eSJuan Manuel MARTINEZ CAAMAÑO return getInliningCostFeatures(CB, TIR, GetAssumptionCache);
54cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
55cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
56cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // Tests that we can retrieve the CostFeatures without an error
TEST(InlineCostTest,CostFeatures)57cc8a346eSJuan Manuel MARTINEZ CAAMAÑO TEST(InlineCostTest, CostFeatures) {
5899f00635SJacob Hegna const auto *const IR = R"IR(
5999f00635SJacob Hegna define i32 @f(i32) {
6099f00635SJacob Hegna ret i32 4
6199f00635SJacob Hegna }
6299f00635SJacob Hegna
6399f00635SJacob Hegna define i32 @g(i32) {
6499f00635SJacob Hegna %2 = call i32 @f(i32 0)
6599f00635SJacob Hegna ret i32 %2
6699f00635SJacob Hegna }
6799f00635SJacob Hegna )IR";
6899f00635SJacob Hegna
6999f00635SJacob Hegna LLVMContext C;
7099f00635SJacob Hegna SMDiagnostic Err;
7199f00635SJacob Hegna std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
7299f00635SJacob Hegna ASSERT_TRUE(M);
7399f00635SJacob Hegna
7499f00635SJacob Hegna auto *G = M->getFunction("g");
7599f00635SJacob Hegna ASSERT_TRUE(G);
7699f00635SJacob Hegna
7799f00635SJacob Hegna // find the call to f in g
78cc8a346eSJuan Manuel MARTINEZ CAAMAÑO CallBase *CB = getCallInFunction(G);
7999f00635SJacob Hegna ASSERT_TRUE(CB);
8099f00635SJacob Hegna
81cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const auto Features = getInliningCostFeaturesForCall(*CB);
8299f00635SJacob Hegna
8399f00635SJacob Hegna // Check that the optional is not empty
8499f00635SJacob Hegna ASSERT_TRUE(Features);
8599f00635SJacob Hegna }
8699f00635SJacob Hegna
87cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // Tests the calculated SROA cost
TEST(InlineCostTest,SROACost)88cc8a346eSJuan Manuel MARTINEZ CAAMAÑO TEST(InlineCostTest, SROACost) {
89cc8a346eSJuan Manuel MARTINEZ CAAMAÑO using namespace llvm;
90cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
91cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const auto *const IR = R"IR(
92cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @f_savings(ptr %var) {
93cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %load = load i32, ptr %var
94cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %inc = add i32 %load, 1
95cc8a346eSJuan Manuel MARTINEZ CAAMAÑO store i32 %inc, ptr %var
96cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ret void
97cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
98cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
99cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @g_savings(i32) {
100cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %var = alloca i32
101cc8a346eSJuan Manuel MARTINEZ CAAMAÑO call void @f_savings(ptr %var)
102cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ret void
103cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
104cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
105cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @f_losses(ptr %var) {
106cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %load = load i32, ptr %var
107cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %inc = add i32 %load, 1
108cc8a346eSJuan Manuel MARTINEZ CAAMAÑO store i32 %inc, ptr %var
109cc8a346eSJuan Manuel MARTINEZ CAAMAÑO call void @prevent_sroa(ptr %var)
110cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ret void
111cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
112cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
113cc8a346eSJuan Manuel MARTINEZ CAAMAÑO define void @g_losses(i32) {
114cc8a346eSJuan Manuel MARTINEZ CAAMAÑO %var = alloca i32
115cc8a346eSJuan Manuel MARTINEZ CAAMAÑO call void @f_losses(ptr %var)
116cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ret void
117cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
118cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
119cc8a346eSJuan Manuel MARTINEZ CAAMAÑO declare void @prevent_sroa(ptr)
120cc8a346eSJuan Manuel MARTINEZ CAAMAÑO )IR";
121cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
122cc8a346eSJuan Manuel MARTINEZ CAAMAÑO LLVMContext C;
123cc8a346eSJuan Manuel MARTINEZ CAAMAÑO SMDiagnostic Err;
124cc8a346eSJuan Manuel MARTINEZ CAAMAÑO std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C);
125cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_TRUE(M);
126cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
127cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const int DefaultInstCost = 5;
128cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const int DefaultAllocaCost = 0;
129cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
130cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const char *GName[] = {"g_savings", "g_losses", nullptr};
131cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0};
132cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost};
133cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
134cc8a346eSJuan Manuel MARTINEZ CAAMAÑO for (unsigned i = 0; GName[i]; ++i) {
135cc8a346eSJuan Manuel MARTINEZ CAAMAÑO auto *G = M->getFunction(GName[i]);
136cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_TRUE(G);
137cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
138cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // find the call to f in g
139cc8a346eSJuan Manuel MARTINEZ CAAMAÑO CallBase *CB = getCallInFunction(G);
140cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_TRUE(CB);
141cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
142cc8a346eSJuan Manuel MARTINEZ CAAMAÑO const auto Features = getInliningCostFeaturesForCall(*CB);
143cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_TRUE(Features);
144cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
145cc8a346eSJuan Manuel MARTINEZ CAAMAÑO // Check the predicted SROA cost
146cc8a346eSJuan Manuel MARTINEZ CAAMAÑO auto GetFeature = [&](InlineCostFeatureIndex I) {
147cc8a346eSJuan Manuel MARTINEZ CAAMAÑO return (*Features)[static_cast<size_t>(I)];
148cc8a346eSJuan Manuel MARTINEZ CAAMAÑO };
149cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]);
150cc8a346eSJuan Manuel MARTINEZ CAAMAÑO ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]);
151cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
152cc8a346eSJuan Manuel MARTINEZ CAAMAÑO }
153cc8a346eSJuan Manuel MARTINEZ CAAMAÑO
15499f00635SJacob Hegna } // namespace
155