1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 class FunctionSpecializationTest : public testing::Test { 27 protected: 28 LLVMContext Ctx; 29 FunctionAnalysisManager FAM; 30 std::unique_ptr<Module> M; 31 std::unique_ptr<SCCPSolver> Solver; 32 33 FunctionSpecializationTest() { 34 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 35 FAM.registerPass([&] { return TargetIRAnalysis(); }); 36 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 37 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 38 FAM.registerPass([&] { return LoopAnalysis(); }); 39 FAM.registerPass([&] { return AssumptionAnalysis(); }); 40 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 41 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 42 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 43 } 44 45 Module &parseModule(const char *ModuleString) { 46 SMDiagnostic Err; 47 M = parseAssemblyString(ModuleString, Err, Ctx); 48 EXPECT_TRUE(M); 49 return *M; 50 } 51 52 FunctionSpecializer getSpecializerFor(Function *F) { 53 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 54 return FAM.getResult<TargetLibraryAnalysis>(F); 55 }; 56 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 57 return FAM.getResult<TargetIRAnalysis>(F); 58 }; 59 auto GetAC = [this](Function &F) -> AssumptionCache & { 60 return FAM.getResult<AssumptionAnalysis>(F); 61 }; 62 auto GetDT = [this](Function &F) -> DominatorTree & { 63 return FAM.getResult<DominatorTreeAnalysis>(F); 64 }; 65 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 66 return FAM.getResult<BlockFrequencyAnalysis>(F); 67 }; 68 69 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 70 71 DominatorTree &DT = GetDT(*F); 72 AssumptionCache &AC = GetAC(*F); 73 Solver->addPredicateInfo(*F, DT, AC); 74 75 Solver->markBlockExecutable(&F->front()); 76 for (Argument &Arg : F->args()) 77 Solver->markOverdefined(&Arg); 78 Solver->solveWhileResolvedUndefsIn(*M); 79 80 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 81 GetAC); 82 } 83 84 Bonus getInstCost(Instruction &I, bool SizeOnly = false) { 85 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 86 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 87 88 Cost CodeSize = 89 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 90 91 Cost Latency = SizeOnly ? 0 : 92 BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 93 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency); 94 95 return {CodeSize, Latency}; 96 } 97 }; 98 99 } // namespace llvm 100 101 using namespace llvm; 102 103 TEST_F(FunctionSpecializationTest, SwitchInst) { 104 const char *ModuleString = R"( 105 define void @foo(i32 %a, i32 %b, i32 %i) { 106 entry: 107 br label %loop 108 loop: 109 switch i32 %i, label %default 110 [ i32 1, label %case1 111 i32 2, label %case2 ] 112 case1: 113 %0 = mul i32 %a, 2 114 %1 = sub i32 6, 5 115 br label %bb1 116 case2: 117 %2 = and i32 %b, 3 118 %3 = sdiv i32 8, 2 119 br label %bb2 120 bb1: 121 %4 = add i32 %0, %b 122 br label %loop 123 bb2: 124 %5 = or i32 %2, %a 125 br label %loop 126 default: 127 ret void 128 } 129 )"; 130 131 Module &M = parseModule(ModuleString); 132 Function *F = M.getFunction("foo"); 133 FunctionSpecializer Specializer = getSpecializerFor(F); 134 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 135 136 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 137 138 auto FuncIter = F->begin(); 139 BasicBlock &Loop = *++FuncIter; 140 BasicBlock &Case1 = *++FuncIter; 141 BasicBlock &Case2 = *++FuncIter; 142 BasicBlock &BB1 = *++FuncIter; 143 BasicBlock &BB2 = *++FuncIter; 144 145 Instruction &Switch = Loop.front(); 146 Instruction &Mul = Case1.front(); 147 Instruction &And = Case2.front(); 148 Instruction &Sdiv = *++Case2.begin(); 149 Instruction &BrBB2 = Case2.back(); 150 Instruction &Add = BB1.front(); 151 Instruction &Or = BB2.front(); 152 Instruction &BrLoop = BB2.back(); 153 154 // mul 155 Bonus Ref = getInstCost(Mul); 156 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 157 EXPECT_EQ(Test, Ref); 158 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 159 160 // and + or + add 161 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 162 Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 163 EXPECT_EQ(Test, Ref); 164 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 165 166 // switch + sdiv + br + br 167 Ref = getInstCost(Switch) + 168 getInstCost(Sdiv, /*SizeOnly =*/ true) + 169 getInstCost(BrBB2, /*SizeOnly =*/ true) + 170 getInstCost(BrLoop, /*SizeOnly =*/ true); 171 Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 172 EXPECT_EQ(Test, Ref); 173 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 174 } 175 176 TEST_F(FunctionSpecializationTest, BranchInst) { 177 const char *ModuleString = R"( 178 define void @foo(i32 %a, i32 %b, i1 %cond) { 179 entry: 180 br label %loop 181 loop: 182 br i1 %cond, label %bb0, label %bb2 183 bb0: 184 %0 = mul i32 %a, 2 185 %1 = sub i32 6, 5 186 br label %bb1 187 bb1: 188 %2 = add i32 %0, %b 189 %3 = sdiv i32 8, 2 190 br label %loop 191 bb2: 192 ret void 193 } 194 )"; 195 196 Module &M = parseModule(ModuleString); 197 Function *F = M.getFunction("foo"); 198 FunctionSpecializer Specializer = getSpecializerFor(F); 199 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 200 201 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 202 Constant *False = ConstantInt::getFalse(M.getContext()); 203 204 auto FuncIter = F->begin(); 205 BasicBlock &Loop = *++FuncIter; 206 BasicBlock &BB0 = *++FuncIter; 207 BasicBlock &BB1 = *++FuncIter; 208 209 Instruction &Branch = Loop.front(); 210 Instruction &Mul = BB0.front(); 211 Instruction &Sub = *++BB0.begin(); 212 Instruction &BrBB1 = BB0.back(); 213 Instruction &Add = BB1.front(); 214 Instruction &Sdiv = *++BB1.begin(); 215 Instruction &BrLoop = BB1.back(); 216 217 // mul 218 Bonus Ref = getInstCost(Mul); 219 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 220 EXPECT_EQ(Test, Ref); 221 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 222 223 // add 224 Ref = getInstCost(Add); 225 Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 226 EXPECT_EQ(Test, Ref); 227 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 228 229 // branch + sub + br + sdiv + br 230 Ref = getInstCost(Branch) + 231 getInstCost(Sub, /*SizeOnly =*/ true) + 232 getInstCost(BrBB1, /*SizeOnly =*/ true) + 233 getInstCost(Sdiv, /*SizeOnly =*/ true) + 234 getInstCost(BrLoop, /*SizeOnly =*/ true); 235 Test = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 236 EXPECT_EQ(Test, Ref); 237 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 238 } 239 240 TEST_F(FunctionSpecializationTest, Misc) { 241 const char *ModuleString = R"( 242 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 243 @g = constant %struct_t zeroinitializer, align 16 244 245 declare i32 @llvm.smax.i32(i32, i32) 246 declare i32 @bar(i32) 247 248 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 249 %cmp = icmp eq i8 %a, 10 250 %ext = zext i1 %cmp to i64 251 %sel = select i1 %cond, i64 %ext, i64 1 252 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 253 %ld = load i32, ptr %gep 254 %fr = freeze i32 %ld 255 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 256 %call = call i32 @bar(i32 %smax) 257 %fr2 = freeze i32 %c 258 %add = add i32 %call, %fr2 259 ret i32 %add 260 } 261 )"; 262 263 Module &M = parseModule(ModuleString); 264 Function *F = M.getFunction("foo"); 265 FunctionSpecializer Specializer = getSpecializerFor(F); 266 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 267 268 GlobalVariable *GV = M.getGlobalVariable("g"); 269 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 270 Constant *True = ConstantInt::getTrue(M.getContext()); 271 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 272 273 auto BlockIter = F->front().begin(); 274 Instruction &Icmp = *BlockIter++; 275 Instruction &Zext = *BlockIter++; 276 Instruction &Select = *BlockIter++; 277 Instruction &Gep = *BlockIter++; 278 Instruction &Load = *BlockIter++; 279 Instruction &Freeze = *BlockIter++; 280 Instruction &Smax = *BlockIter++; 281 282 // icmp + zext 283 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext); 284 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 285 EXPECT_EQ(Test, Ref); 286 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 287 288 // select 289 Ref = getInstCost(Select); 290 Test = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 291 EXPECT_EQ(Test, Ref); 292 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 293 294 // gep + load + freeze + smax 295 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 296 getInstCost(Smax); 297 Test = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 298 EXPECT_EQ(Test, Ref); 299 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 300 301 Test = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor); 302 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 303 } 304 305 TEST_F(FunctionSpecializationTest, PhiNode) { 306 const char *ModuleString = R"( 307 define void @foo(i32 %a, i32 %b, i32 %i) { 308 entry: 309 br label %loop 310 loop: 311 switch i32 %i, label %default 312 [ i32 1, label %case1 313 i32 2, label %case2 ] 314 case1: 315 %0 = add i32 %a, 1 316 br label %bb 317 case2: 318 %1 = sub i32 %b, 1 319 br label %bb 320 bb: 321 %2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ] 322 %3 = icmp eq i32 %2, 2 323 br i1 %3, label %bb, label %loop 324 default: 325 ret void 326 } 327 )"; 328 329 Module &M = parseModule(ModuleString); 330 Function *F = M.getFunction("foo"); 331 FunctionSpecializer Specializer = getSpecializerFor(F); 332 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 333 334 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 335 336 auto FuncIter = F->begin(); 337 for (int I = 0; I < 4; ++I) 338 ++FuncIter; 339 340 BasicBlock &BB = *FuncIter; 341 342 Instruction &Phi = BB.front(); 343 Instruction &Icmp = *++BB.begin(); 344 Instruction &Branch = BB.back(); 345 346 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor) + 347 Specializer.getSpecializationBonus(F->getArg(1), One, Visitor) + 348 Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 349 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 350 351 // phi + icmp + branch 352 Bonus Ref = getInstCost(Phi) + getInstCost(Icmp) + getInstCost(Branch); 353 Test = Visitor.getBonusFromPendingPHIs(); 354 EXPECT_EQ(Test, Ref); 355 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 356 } 357 358