1 //===- InlineCostTest.cpp - test for InlineCost ---------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/InlineCost.h" 10 #include "llvm/Analysis/AssumptionCache.h" 11 #include "llvm/Analysis/InlineModelFeatureMaps.h" 12 #include "llvm/Analysis/TargetTransformInfo.h" 13 #include "llvm/AsmParser/Parser.h" 14 #include "llvm/IR/InstIterator.h" 15 #include "llvm/IR/Instructions.h" 16 #include "llvm/IR/LLVMContext.h" 17 #include "llvm/IR/Module.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "gtest/gtest.h" 20 21 namespace { 22 23 using namespace llvm; 24 25 CallBase *getCallInFunction(Function *F) { 26 for (auto &I : instructions(F)) { 27 if (auto *CB = dyn_cast<llvm::CallBase>(&I)) 28 return CB; 29 } 30 return nullptr; 31 } 32 33 std::optional<InlineCostFeatures> getInliningCostFeaturesForCall(CallBase &CB) { 34 ModuleAnalysisManager MAM; 35 FunctionAnalysisManager FAM; 36 FAM.registerPass([&] { return TargetIRAnalysis(); }); 37 FAM.registerPass([&] { return ModuleAnalysisManagerFunctionProxy(MAM); }); 38 FAM.registerPass([&] { return AssumptionAnalysis(); }); 39 MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(FAM); }); 40 41 MAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 42 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 43 44 ModulePassManager MPM; 45 MPM.run(*CB.getModule(), MAM); 46 47 auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { 48 return FAM.getResult<AssumptionAnalysis>(F); 49 }; 50 auto &TIR = FAM.getResult<TargetIRAnalysis>(*CB.getFunction()); 51 52 return getInliningCostFeatures(CB, TIR, GetAssumptionCache); 53 } 54 55 // Tests that we can retrieve the CostFeatures without an error 56 TEST(InlineCostTest, CostFeatures) { 57 const auto *const IR = R"IR( 58 define i32 @f(i32) { 59 ret i32 4 60 } 61 62 define i32 @g(i32) { 63 %2 = call i32 @f(i32 0) 64 ret i32 %2 65 } 66 )IR"; 67 68 LLVMContext C; 69 SMDiagnostic Err; 70 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C); 71 ASSERT_TRUE(M); 72 73 auto *G = M->getFunction("g"); 74 ASSERT_TRUE(G); 75 76 // find the call to f in g 77 CallBase *CB = getCallInFunction(G); 78 ASSERT_TRUE(CB); 79 80 const auto Features = getInliningCostFeaturesForCall(*CB); 81 82 // Check that the optional is not empty 83 ASSERT_TRUE(Features); 84 } 85 86 // Tests the calculated SROA cost 87 TEST(InlineCostTest, SROACost) { 88 using namespace llvm; 89 90 const auto *const IR = R"IR( 91 define void @f_savings(ptr %var) { 92 %load = load i32, ptr %var 93 %inc = add i32 %load, 1 94 store i32 %inc, ptr %var 95 ret void 96 } 97 98 define void @g_savings(i32) { 99 %var = alloca i32 100 call void @f_savings(ptr %var) 101 ret void 102 } 103 104 define void @f_losses(ptr %var) { 105 %load = load i32, ptr %var 106 %inc = add i32 %load, 1 107 store i32 %inc, ptr %var 108 call void @prevent_sroa(ptr %var) 109 ret void 110 } 111 112 define void @g_losses(i32) { 113 %var = alloca i32 114 call void @f_losses(ptr %var) 115 ret void 116 } 117 118 declare void @prevent_sroa(ptr) 119 )IR"; 120 121 LLVMContext C; 122 SMDiagnostic Err; 123 std::unique_ptr<Module> M = parseAssemblyString(IR, Err, C); 124 ASSERT_TRUE(M); 125 126 const int DefaultInstCost = 5; 127 const int DefaultAllocaCost = 0; 128 129 const char *GName[] = {"g_savings", "g_losses", nullptr}; 130 const int Savings[] = {2 * DefaultInstCost + DefaultAllocaCost, 0}; 131 const int Losses[] = {0, 2 * DefaultInstCost + DefaultAllocaCost}; 132 133 for (unsigned i = 0; GName[i]; ++i) { 134 auto *G = M->getFunction(GName[i]); 135 ASSERT_TRUE(G); 136 137 // find the call to f in g 138 CallBase *CB = getCallInFunction(G); 139 ASSERT_TRUE(CB); 140 141 const auto Features = getInliningCostFeaturesForCall(*CB); 142 ASSERT_TRUE(Features); 143 144 // Check the predicted SROA cost 145 auto GetFeature = [&](InlineCostFeatureIndex I) { 146 return (*Features)[static_cast<size_t>(I)]; 147 }; 148 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_savings), Savings[i]); 149 ASSERT_EQ(GetFeature(InlineCostFeatureIndex::sroa_losses), Losses[i]); 150 } 151 } 152 153 } // namespace 154