1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 class FunctionSpecializationTest : public testing::Test { 27 protected: 28 LLVMContext Ctx; 29 FunctionAnalysisManager FAM; 30 std::unique_ptr<Module> M; 31 std::unique_ptr<SCCPSolver> Solver; 32 33 FunctionSpecializationTest() { 34 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 35 FAM.registerPass([&] { return TargetIRAnalysis(); }); 36 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 37 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 38 FAM.registerPass([&] { return LoopAnalysis(); }); 39 FAM.registerPass([&] { return AssumptionAnalysis(); }); 40 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 41 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 42 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 43 } 44 45 Module &parseModule(const char *ModuleString) { 46 SMDiagnostic Err; 47 M = parseAssemblyString(ModuleString, Err, Ctx); 48 EXPECT_TRUE(M); 49 return *M; 50 } 51 52 FunctionSpecializer getSpecializerFor(Function *F) { 53 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 54 return FAM.getResult<TargetLibraryAnalysis>(F); 55 }; 56 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 57 return FAM.getResult<TargetIRAnalysis>(F); 58 }; 59 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 60 return FAM.getResult<BlockFrequencyAnalysis>(F); 61 }; 62 auto GetAC = [this](Function &F) -> AssumptionCache & { 63 return FAM.getResult<AssumptionAnalysis>(F); 64 }; 65 auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn { 66 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 67 return { std::make_unique<PredicateInfo>(F, DT, 68 FAM.getResult<AssumptionAnalysis>(F)), 69 &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) }; 70 }; 71 72 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 73 74 Solver->addAnalysis(*F, GetAnalysis(*F)); 75 Solver->markBlockExecutable(&F->front()); 76 for (Argument &Arg : F->args()) 77 Solver->markOverdefined(&Arg); 78 Solver->solveWhileResolvedUndefsIn(*M); 79 80 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 81 GetAC); 82 } 83 84 Cost getInstCost(Instruction &I) { 85 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 86 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 87 88 uint64_t Weight = FunctionSpecializer::getBlockFreqMultiplier() * 89 BFI.getBlockFreq(I.getParent()).getFrequency() / 90 BFI.getEntryFreq(); 91 return Weight * 92 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); 93 } 94 }; 95 96 } // namespace llvm 97 98 using namespace llvm; 99 100 TEST_F(FunctionSpecializationTest, SwitchInst) { 101 const char *ModuleString = R"( 102 define void @foo(i32 %a, i32 %b, i32 %i) { 103 entry: 104 switch i32 %i, label %default 105 [ i32 1, label %case1 106 i32 2, label %case2 ] 107 case1: 108 %0 = mul i32 %a, 2 109 %1 = sub i32 6, 5 110 br label %bb1 111 case2: 112 %2 = and i32 %b, 3 113 %3 = sdiv i32 8, 2 114 br label %bb2 115 bb1: 116 %4 = add i32 %0, %b 117 br label %default 118 bb2: 119 %5 = or i32 %2, %a 120 br label %default 121 default: 122 ret void 123 } 124 )"; 125 126 Module &M = parseModule(ModuleString); 127 Function *F = M.getFunction("foo"); 128 FunctionSpecializer Specializer = getSpecializerFor(F); 129 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 130 131 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 132 133 auto FuncIter = F->begin(); 134 BasicBlock &Case1 = *++FuncIter; 135 BasicBlock &Case2 = *++FuncIter; 136 BasicBlock &BB1 = *++FuncIter; 137 BasicBlock &BB2 = *++FuncIter; 138 139 Instruction &Mul = Case1.front(); 140 Instruction &And = Case2.front(); 141 Instruction &Sdiv = *++Case2.begin(); 142 Instruction &BrBB2 = Case2.back(); 143 Instruction &Add = BB1.front(); 144 Instruction &Or = BB2.front(); 145 Instruction &BrDefault = BB2.back(); 146 147 // mul 148 Cost Ref = getInstCost(Mul); 149 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 150 EXPECT_EQ(Bonus, Ref); 151 152 // and + or + add 153 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 154 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 155 EXPECT_EQ(Bonus, Ref); 156 157 // sdiv + br + br 158 Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault); 159 Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 160 EXPECT_EQ(Bonus, Ref); 161 } 162 163 TEST_F(FunctionSpecializationTest, BranchInst) { 164 const char *ModuleString = R"( 165 define void @foo(i32 %a, i32 %b, i1 %cond) { 166 entry: 167 br i1 %cond, label %bb0, label %bb2 168 bb0: 169 %0 = mul i32 %a, 2 170 %1 = sub i32 6, 5 171 br label %bb1 172 bb1: 173 %2 = add i32 %0, %b 174 %3 = sdiv i32 8, 2 175 br label %bb2 176 bb2: 177 ret void 178 } 179 )"; 180 181 Module &M = parseModule(ModuleString); 182 Function *F = M.getFunction("foo"); 183 FunctionSpecializer Specializer = getSpecializerFor(F); 184 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 185 186 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 187 Constant *False = ConstantInt::getFalse(M.getContext()); 188 189 auto FuncIter = F->begin(); 190 BasicBlock &BB0 = *++FuncIter; 191 BasicBlock &BB1 = *++FuncIter; 192 193 Instruction &Mul = BB0.front(); 194 Instruction &Sub = *++BB0.begin(); 195 Instruction &BrBB1 = BB0.back(); 196 Instruction &Add = BB1.front(); 197 Instruction &Sdiv = *++BB1.begin(); 198 Instruction &BrBB2 = BB1.back(); 199 200 // mul 201 Cost Ref = getInstCost(Mul); 202 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 203 EXPECT_EQ(Bonus, Ref); 204 205 // add 206 Ref = getInstCost(Add); 207 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 208 EXPECT_EQ(Bonus, Ref); 209 210 // sub + br + sdiv + br 211 Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) + 212 getInstCost(BrBB2); 213 Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 214 EXPECT_EQ(Bonus, Ref); 215 } 216 217 TEST_F(FunctionSpecializationTest, Misc) { 218 const char *ModuleString = R"( 219 @g = constant [2 x i32] zeroinitializer, align 4 220 221 define i32 @foo(i8 %a, i1 %cond, ptr %b) { 222 %cmp = icmp eq i8 %a, 10 223 %ext = zext i1 %cmp to i32 224 %sel = select i1 %cond, i32 %ext, i32 1 225 %gep = getelementptr i32, ptr %b, i32 %sel 226 %ld = load i32, ptr %gep 227 ret i32 %ld 228 } 229 )"; 230 231 Module &M = parseModule(ModuleString); 232 Function *F = M.getFunction("foo"); 233 FunctionSpecializer Specializer = getSpecializerFor(F); 234 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 235 236 GlobalVariable *GV = M.getGlobalVariable("g"); 237 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 238 Constant *True = ConstantInt::getTrue(M.getContext()); 239 240 auto BlockIter = F->front().begin(); 241 Instruction &Icmp = *BlockIter++; 242 Instruction &Zext = *BlockIter++; 243 Instruction &Select = *BlockIter++; 244 Instruction &Gep = *BlockIter++; 245 Instruction &Load = *BlockIter++; 246 247 // icmp + zext 248 Cost Ref = getInstCost(Icmp) + getInstCost(Zext); 249 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 250 EXPECT_EQ(Bonus, Ref); 251 252 // select 253 Ref = getInstCost(Select); 254 Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 255 EXPECT_EQ(Bonus, Ref); 256 257 // gep + load 258 Ref = getInstCost(Gep) + getInstCost(Load); 259 Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 260 EXPECT_EQ(Bonus, Ref); 261 } 262