1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 class FunctionSpecializationTest : public testing::Test { 27 protected: 28 LLVMContext Ctx; 29 FunctionAnalysisManager FAM; 30 std::unique_ptr<Module> M; 31 std::unique_ptr<SCCPSolver> Solver; 32 33 FunctionSpecializationTest() { 34 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 35 FAM.registerPass([&] { return TargetIRAnalysis(); }); 36 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 37 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 38 FAM.registerPass([&] { return LoopAnalysis(); }); 39 FAM.registerPass([&] { return AssumptionAnalysis(); }); 40 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 41 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 42 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 43 } 44 45 Module &parseModule(const char *ModuleString) { 46 SMDiagnostic Err; 47 M = parseAssemblyString(ModuleString, Err, Ctx); 48 EXPECT_TRUE(M); 49 return *M; 50 } 51 52 FunctionSpecializer getSpecializerFor(Function *F) { 53 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 54 return FAM.getResult<TargetLibraryAnalysis>(F); 55 }; 56 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 57 return FAM.getResult<TargetIRAnalysis>(F); 58 }; 59 auto GetAC = [this](Function &F) -> AssumptionCache & { 60 return FAM.getResult<AssumptionAnalysis>(F); 61 }; 62 auto GetDT = [this](Function &F) -> DominatorTree & { 63 return FAM.getResult<DominatorTreeAnalysis>(F); 64 }; 65 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 66 return FAM.getResult<BlockFrequencyAnalysis>(F); 67 }; 68 69 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 70 71 DominatorTree &DT = GetDT(*F); 72 AssumptionCache &AC = GetAC(*F); 73 Solver->addPredicateInfo(*F, DT, AC); 74 75 Solver->markBlockExecutable(&F->front()); 76 for (Argument &Arg : F->args()) 77 Solver->markOverdefined(&Arg); 78 Solver->solveWhileResolvedUndefsIn(*M); 79 80 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 81 GetAC); 82 } 83 84 Cost getInstCost(Instruction &I) { 85 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 86 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 87 88 return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 89 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); 90 } 91 }; 92 93 } // namespace llvm 94 95 using namespace llvm; 96 97 TEST_F(FunctionSpecializationTest, SwitchInst) { 98 const char *ModuleString = R"( 99 define void @foo(i32 %a, i32 %b, i32 %i) { 100 entry: 101 br label %loop 102 loop: 103 switch i32 %i, label %default 104 [ i32 1, label %case1 105 i32 2, label %case2 ] 106 case1: 107 %0 = mul i32 %a, 2 108 %1 = sub i32 6, 5 109 br label %bb1 110 case2: 111 %2 = and i32 %b, 3 112 %3 = sdiv i32 8, 2 113 br label %bb2 114 bb1: 115 %4 = add i32 %0, %b 116 br label %loop 117 bb2: 118 %5 = or i32 %2, %a 119 br label %loop 120 default: 121 ret void 122 } 123 )"; 124 125 Module &M = parseModule(ModuleString); 126 Function *F = M.getFunction("foo"); 127 FunctionSpecializer Specializer = getSpecializerFor(F); 128 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 129 130 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 131 132 auto FuncIter = F->begin(); 133 ++FuncIter; 134 BasicBlock &Case1 = *++FuncIter; 135 BasicBlock &Case2 = *++FuncIter; 136 BasicBlock &BB1 = *++FuncIter; 137 BasicBlock &BB2 = *++FuncIter; 138 139 Instruction &Mul = Case1.front(); 140 Instruction &And = Case2.front(); 141 Instruction &Sdiv = *++Case2.begin(); 142 Instruction &BrBB2 = Case2.back(); 143 Instruction &Add = BB1.front(); 144 Instruction &Or = BB2.front(); 145 Instruction &BrLoop = BB2.back(); 146 147 // mul 148 Cost Ref = getInstCost(Mul); 149 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 150 EXPECT_EQ(Bonus, Ref); 151 EXPECT_TRUE(Bonus > 0); 152 153 // and + or + add 154 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 155 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 156 EXPECT_EQ(Bonus, Ref); 157 EXPECT_TRUE(Bonus > 0); 158 159 // sdiv + br + br 160 Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrLoop); 161 Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 162 EXPECT_EQ(Bonus, Ref); 163 EXPECT_TRUE(Bonus > 0); 164 } 165 166 TEST_F(FunctionSpecializationTest, BranchInst) { 167 const char *ModuleString = R"( 168 define void @foo(i32 %a, i32 %b, i1 %cond) { 169 entry: 170 br label %loop 171 loop: 172 br i1 %cond, label %bb0, label %bb2 173 bb0: 174 %0 = mul i32 %a, 2 175 %1 = sub i32 6, 5 176 br label %bb1 177 bb1: 178 %2 = add i32 %0, %b 179 %3 = sdiv i32 8, 2 180 br label %loop 181 bb2: 182 ret void 183 } 184 )"; 185 186 Module &M = parseModule(ModuleString); 187 Function *F = M.getFunction("foo"); 188 FunctionSpecializer Specializer = getSpecializerFor(F); 189 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 190 191 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 192 Constant *False = ConstantInt::getFalse(M.getContext()); 193 194 auto FuncIter = F->begin(); 195 ++FuncIter; 196 BasicBlock &BB0 = *++FuncIter; 197 BasicBlock &BB1 = *++FuncIter; 198 199 Instruction &Mul = BB0.front(); 200 Instruction &Sub = *++BB0.begin(); 201 Instruction &BrBB1 = BB0.back(); 202 Instruction &Add = BB1.front(); 203 Instruction &Sdiv = *++BB1.begin(); 204 Instruction &BrLoop = BB1.back(); 205 206 // mul 207 Cost Ref = getInstCost(Mul); 208 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 209 EXPECT_EQ(Bonus, Ref); 210 EXPECT_TRUE(Bonus > 0); 211 212 // add 213 Ref = getInstCost(Add); 214 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 215 EXPECT_EQ(Bonus, Ref); 216 EXPECT_TRUE(Bonus > 0); 217 218 // sub + br + sdiv + br 219 Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) + 220 getInstCost(BrLoop); 221 Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 222 EXPECT_EQ(Bonus, Ref); 223 EXPECT_TRUE(Bonus > 0); 224 } 225 226 TEST_F(FunctionSpecializationTest, Misc) { 227 const char *ModuleString = R"( 228 @g = constant [2 x i32] zeroinitializer, align 4 229 230 declare i32 @llvm.smax.i32(i32, i32) 231 declare i32 @bar(i32) 232 233 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 234 %cmp = icmp eq i8 %a, 10 235 %ext = zext i1 %cmp to i32 236 %sel = select i1 %cond, i32 %ext, i32 1 237 %gep = getelementptr i32, ptr %b, i32 %sel 238 %ld = load i32, ptr %gep 239 %fr = freeze i32 %ld 240 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 241 %call = call i32 @bar(i32 %smax) 242 %fr2 = freeze i32 %c 243 %add = add i32 %call, %fr2 244 ret i32 %add 245 } 246 )"; 247 248 Module &M = parseModule(ModuleString); 249 Function *F = M.getFunction("foo"); 250 FunctionSpecializer Specializer = getSpecializerFor(F); 251 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 252 253 GlobalVariable *GV = M.getGlobalVariable("g"); 254 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 255 Constant *True = ConstantInt::getTrue(M.getContext()); 256 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 257 258 auto BlockIter = F->front().begin(); 259 Instruction &Icmp = *BlockIter++; 260 Instruction &Zext = *BlockIter++; 261 Instruction &Select = *BlockIter++; 262 Instruction &Gep = *BlockIter++; 263 Instruction &Load = *BlockIter++; 264 Instruction &Freeze = *BlockIter++; 265 Instruction &Smax = *BlockIter++; 266 267 // icmp + zext 268 Cost Ref = getInstCost(Icmp) + getInstCost(Zext); 269 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 270 EXPECT_EQ(Bonus, Ref); 271 EXPECT_TRUE(Bonus > 0); 272 273 // select 274 Ref = getInstCost(Select); 275 Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 276 EXPECT_EQ(Bonus, Ref); 277 EXPECT_TRUE(Bonus > 0); 278 279 // gep + load + freeze + smax 280 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 281 getInstCost(Smax); 282 Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 283 EXPECT_EQ(Bonus, Ref); 284 EXPECT_TRUE(Bonus > 0); 285 286 Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor); 287 EXPECT_TRUE(Bonus == 0); 288 } 289