1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 static void removeSSACopy(Function &F) { 27 for (BasicBlock &BB : F) { 28 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 29 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 30 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 31 continue; 32 Inst.replaceAllUsesWith(II->getOperand(0)); 33 Inst.eraseFromParent(); 34 } 35 } 36 } 37 } 38 39 class FunctionSpecializationTest : public testing::Test { 40 protected: 41 LLVMContext Ctx; 42 FunctionAnalysisManager FAM; 43 std::unique_ptr<Module> M; 44 std::unique_ptr<SCCPSolver> Solver; 45 46 FunctionSpecializationTest() { 47 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 48 FAM.registerPass([&] { return TargetIRAnalysis(); }); 49 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 50 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 51 FAM.registerPass([&] { return LoopAnalysis(); }); 52 FAM.registerPass([&] { return AssumptionAnalysis(); }); 53 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 54 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 55 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 56 } 57 58 Module &parseModule(const char *ModuleString) { 59 SMDiagnostic Err; 60 M = parseAssemblyString(ModuleString, Err, Ctx); 61 EXPECT_TRUE(M); 62 return *M; 63 } 64 65 FunctionSpecializer getSpecializerFor(Function *F) { 66 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 67 return FAM.getResult<TargetLibraryAnalysis>(F); 68 }; 69 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 70 return FAM.getResult<TargetIRAnalysis>(F); 71 }; 72 auto GetAC = [this](Function &F) -> AssumptionCache & { 73 return FAM.getResult<AssumptionAnalysis>(F); 74 }; 75 auto GetDT = [this](Function &F) -> DominatorTree & { 76 return FAM.getResult<DominatorTreeAnalysis>(F); 77 }; 78 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 79 return FAM.getResult<BlockFrequencyAnalysis>(F); 80 }; 81 82 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 83 84 DominatorTree &DT = GetDT(*F); 85 AssumptionCache &AC = GetAC(*F); 86 Solver->addPredicateInfo(*F, DT, AC); 87 88 Solver->markBlockExecutable(&F->front()); 89 for (Argument &Arg : F->args()) 90 Solver->markOverdefined(&Arg); 91 Solver->solveWhileResolvedUndefsIn(*M); 92 93 removeSSACopy(*F); 94 95 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 96 GetAC); 97 } 98 99 Bonus getInstCost(Instruction &I, bool SizeOnly = false) { 100 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 102 103 Cost CodeSize = 104 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 105 106 Cost Latency = 107 SizeOnly 108 ? 0 109 : BFI.getBlockFreq(I.getParent()).getFrequency() / 110 BFI.getEntryFreq().getFrequency() * 111 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency); 112 113 return {CodeSize, Latency}; 114 } 115 }; 116 117 } // namespace llvm 118 119 using namespace llvm; 120 121 TEST_F(FunctionSpecializationTest, SwitchInst) { 122 const char *ModuleString = R"( 123 define void @foo(i32 %a, i32 %b, i32 %i) { 124 entry: 125 br label %loop 126 loop: 127 switch i32 %i, label %default 128 [ i32 1, label %case1 129 i32 2, label %case2 ] 130 case1: 131 %0 = mul i32 %a, 2 132 %1 = sub i32 6, 5 133 br label %bb1 134 case2: 135 %2 = and i32 %b, 3 136 %3 = sdiv i32 8, 2 137 br label %bb2 138 bb1: 139 %4 = add i32 %0, %b 140 br label %loop 141 bb2: 142 %5 = or i32 %2, %a 143 br label %loop 144 default: 145 ret void 146 } 147 )"; 148 149 Module &M = parseModule(ModuleString); 150 Function *F = M.getFunction("foo"); 151 FunctionSpecializer Specializer = getSpecializerFor(F); 152 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 153 154 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 155 156 auto FuncIter = F->begin(); 157 BasicBlock &Loop = *++FuncIter; 158 BasicBlock &Case1 = *++FuncIter; 159 BasicBlock &Case2 = *++FuncIter; 160 BasicBlock &BB1 = *++FuncIter; 161 BasicBlock &BB2 = *++FuncIter; 162 163 Instruction &Switch = Loop.front(); 164 Instruction &Mul = Case1.front(); 165 Instruction &And = Case2.front(); 166 Instruction &Sdiv = *++Case2.begin(); 167 Instruction &BrBB2 = Case2.back(); 168 Instruction &Add = BB1.front(); 169 Instruction &Or = BB2.front(); 170 Instruction &BrLoop = BB2.back(); 171 172 // mul 173 Bonus Ref = getInstCost(Mul); 174 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 175 EXPECT_EQ(Test, Ref); 176 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 177 178 // and + or + add 179 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 180 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 181 EXPECT_EQ(Test, Ref); 182 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 183 184 // switch + sdiv + br + br 185 Ref = getInstCost(Switch) + 186 getInstCost(Sdiv, /*SizeOnly =*/ true) + 187 getInstCost(BrBB2, /*SizeOnly =*/ true) + 188 getInstCost(BrLoop, /*SizeOnly =*/ true); 189 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 190 EXPECT_EQ(Test, Ref); 191 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 192 } 193 194 TEST_F(FunctionSpecializationTest, BranchInst) { 195 const char *ModuleString = R"( 196 define void @foo(i32 %a, i32 %b, i1 %cond) { 197 entry: 198 br label %loop 199 loop: 200 br i1 %cond, label %bb0, label %bb3 201 bb0: 202 %0 = mul i32 %a, 2 203 %1 = sub i32 6, 5 204 br i1 %cond, label %bb1, label %bb2 205 bb1: 206 %2 = add i32 %0, %b 207 %3 = sdiv i32 8, 2 208 br label %bb2 209 bb2: 210 br label %loop 211 bb3: 212 ret void 213 } 214 )"; 215 216 Module &M = parseModule(ModuleString); 217 Function *F = M.getFunction("foo"); 218 FunctionSpecializer Specializer = getSpecializerFor(F); 219 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 220 221 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 222 Constant *False = ConstantInt::getFalse(M.getContext()); 223 224 auto FuncIter = F->begin(); 225 BasicBlock &Loop = *++FuncIter; 226 BasicBlock &BB0 = *++FuncIter; 227 BasicBlock &BB1 = *++FuncIter; 228 BasicBlock &BB2 = *++FuncIter; 229 230 Instruction &Branch = Loop.front(); 231 Instruction &Mul = BB0.front(); 232 Instruction &Sub = *++BB0.begin(); 233 Instruction &BrBB1BB2 = BB0.back(); 234 Instruction &Add = BB1.front(); 235 Instruction &Sdiv = *++BB1.begin(); 236 Instruction &BrBB2 = BB1.back(); 237 Instruction &BrLoop = BB2.front(); 238 239 // mul 240 Bonus Ref = getInstCost(Mul); 241 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 242 EXPECT_EQ(Test, Ref); 243 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 244 245 // add 246 Ref = getInstCost(Add); 247 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 248 EXPECT_EQ(Test, Ref); 249 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 250 251 // branch + sub + br + sdiv + br 252 Ref = getInstCost(Branch) + 253 getInstCost(Sub, /*SizeOnly =*/ true) + 254 getInstCost(BrBB1BB2) + 255 getInstCost(Sdiv, /*SizeOnly =*/ true) + 256 getInstCost(BrBB2, /*SizeOnly =*/ true) + 257 getInstCost(BrLoop, /*SizeOnly =*/ true); 258 Test = Visitor.getSpecializationBonus(F->getArg(2), False); 259 EXPECT_EQ(Test, Ref); 260 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 261 } 262 263 TEST_F(FunctionSpecializationTest, Misc) { 264 const char *ModuleString = R"( 265 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 266 @g = constant %struct_t zeroinitializer, align 16 267 268 declare i32 @llvm.smax.i32(i32, i32) 269 declare i32 @bar(i32) 270 271 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 272 %cmp = icmp eq i8 %a, 10 273 %ext = zext i1 %cmp to i64 274 %sel = select i1 %cond, i64 %ext, i64 1 275 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 276 %ld = load i32, ptr %gep 277 %fr = freeze i32 %ld 278 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 279 %call = call i32 @bar(i32 %smax) 280 %fr2 = freeze i32 %c 281 %add = add i32 %call, %fr2 282 ret i32 %add 283 } 284 )"; 285 286 Module &M = parseModule(ModuleString); 287 Function *F = M.getFunction("foo"); 288 FunctionSpecializer Specializer = getSpecializerFor(F); 289 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 290 291 GlobalVariable *GV = M.getGlobalVariable("g"); 292 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 293 Constant *True = ConstantInt::getTrue(M.getContext()); 294 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 295 296 auto BlockIter = F->front().begin(); 297 Instruction &Icmp = *BlockIter++; 298 Instruction &Zext = *BlockIter++; 299 Instruction &Select = *BlockIter++; 300 Instruction &Gep = *BlockIter++; 301 Instruction &Load = *BlockIter++; 302 Instruction &Freeze = *BlockIter++; 303 Instruction &Smax = *BlockIter++; 304 305 // icmp + zext 306 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext); 307 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 308 EXPECT_EQ(Test, Ref); 309 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 310 311 // select 312 Ref = getInstCost(Select); 313 Test = Visitor.getSpecializationBonus(F->getArg(1), True); 314 EXPECT_EQ(Test, Ref); 315 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 316 317 // gep + load + freeze + smax 318 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 319 getInstCost(Smax); 320 Test = Visitor.getSpecializationBonus(F->getArg(2), GV); 321 EXPECT_EQ(Test, Ref); 322 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 323 324 Test = Visitor.getSpecializationBonus(F->getArg(3), Undef); 325 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 326 } 327 328 TEST_F(FunctionSpecializationTest, PhiNode) { 329 const char *ModuleString = R"( 330 define void @foo(i32 %a, i32 %b, i32 %i) { 331 entry: 332 br label %loop 333 loop: 334 %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 335 switch i32 %i, label %default 336 [ i32 1, label %case1 337 i32 2, label %case2 ] 338 case1: 339 %1 = add i32 %0, 1 340 br label %bb 341 case2: 342 %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 343 br label %bb 344 bb: 345 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 346 %4 = icmp eq i32 %3, 1 347 br i1 %4, label %bb, label %loop 348 default: 349 ret void 350 } 351 )"; 352 353 Module &M = parseModule(ModuleString); 354 Function *F = M.getFunction("foo"); 355 FunctionSpecializer Specializer = getSpecializerFor(F); 356 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 357 358 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 359 360 auto FuncIter = F->begin(); 361 BasicBlock &Loop = *++FuncIter; 362 BasicBlock &Case1 = *++FuncIter; 363 BasicBlock &Case2 = *++FuncIter; 364 BasicBlock &BB = *++FuncIter; 365 366 Instruction &PhiLoop = Loop.front(); 367 Instruction &Switch = Loop.back(); 368 Instruction &Add = Case1.front(); 369 Instruction &PhiCase2 = Case2.front(); 370 Instruction &BrBB = Case2.back(); 371 Instruction &PhiBB = BB.front(); 372 Instruction &Icmp = *++BB.begin(); 373 Instruction &Branch = BB.back(); 374 375 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 376 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 377 378 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 379 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 380 381 // switch + phi + br 382 Bonus Ref = getInstCost(Switch) + 383 getInstCost(PhiCase2, /*SizeOnly =*/ true) + 384 getInstCost(BrBB, /*SizeOnly =*/ true); 385 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 386 EXPECT_EQ(Test, Ref); 387 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 388 389 // phi + phi + add + icmp + branch 390 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) + 391 getInstCost(Icmp) + getInstCost(Branch); 392 Test = Visitor.getBonusFromPendingPHIs(); 393 EXPECT_EQ(Test, Ref); 394 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 395 } 396 397