1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/IR/PassInstrumentation.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 21 #include "llvm/Transforms/Utils/SCCPSolver.h" 22 #include "gtest/gtest.h" 23 #include <memory> 24 25 namespace llvm { 26 27 static void removeSSACopy(Function &F) { 28 for (BasicBlock &BB : F) { 29 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 30 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 31 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 32 continue; 33 Inst.replaceAllUsesWith(II->getOperand(0)); 34 Inst.eraseFromParent(); 35 } 36 } 37 } 38 } 39 40 class FunctionSpecializationTest : public testing::Test { 41 protected: 42 LLVMContext Ctx; 43 FunctionAnalysisManager FAM; 44 std::unique_ptr<Module> M; 45 std::unique_ptr<SCCPSolver> Solver; 46 SmallVector<Instruction *, 8> KnownConstants; 47 48 FunctionSpecializationTest() { 49 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 50 FAM.registerPass([&] { return TargetIRAnalysis(); }); 51 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 52 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 53 FAM.registerPass([&] { return LoopAnalysis(); }); 54 FAM.registerPass([&] { return AssumptionAnalysis(); }); 55 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 56 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 57 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 58 } 59 60 Module &parseModule(const char *ModuleString) { 61 SMDiagnostic Err; 62 M = parseAssemblyString(ModuleString, Err, Ctx); 63 EXPECT_TRUE(M); 64 return *M; 65 } 66 67 FunctionSpecializer getSpecializerFor(Function *F) { 68 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 69 return FAM.getResult<TargetLibraryAnalysis>(F); 70 }; 71 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 72 return FAM.getResult<TargetIRAnalysis>(F); 73 }; 74 auto GetAC = [this](Function &F) -> AssumptionCache & { 75 return FAM.getResult<AssumptionAnalysis>(F); 76 }; 77 auto GetDT = [this](Function &F) -> DominatorTree & { 78 return FAM.getResult<DominatorTreeAnalysis>(F); 79 }; 80 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 81 return FAM.getResult<BlockFrequencyAnalysis>(F); 82 }; 83 84 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 85 86 DominatorTree &DT = GetDT(*F); 87 AssumptionCache &AC = GetAC(*F); 88 Solver->addPredicateInfo(*F, DT, AC); 89 90 Solver->markBlockExecutable(&F->front()); 91 for (Argument &Arg : F->args()) 92 Solver->markOverdefined(&Arg); 93 Solver->solveWhileResolvedUndefsIn(*M); 94 95 removeSSACopy(*F); 96 97 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 98 GetAC); 99 } 100 101 Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) { 102 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 103 104 Cost CodeSize = 105 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 106 107 if (HasLatencySavings) 108 KnownConstants.push_back(&I); 109 110 return CodeSize; 111 } 112 113 Cost getLatencySavings(Function *F) { 114 auto &TTI = FAM.getResult<TargetIRAnalysis>(*F); 115 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F); 116 117 Cost Latency = 0; 118 for (const Instruction *I : KnownConstants) 119 Latency += BFI.getBlockFreq(I->getParent()).getFrequency() / 120 BFI.getEntryFreq().getFrequency() * 121 TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency); 122 123 return Latency; 124 } 125 }; 126 127 } // namespace llvm 128 129 using namespace llvm; 130 131 TEST_F(FunctionSpecializationTest, SwitchInst) { 132 const char *ModuleString = R"( 133 define void @foo(i32 %a, i32 %b, i32 %i) { 134 entry: 135 br label %loop 136 loop: 137 switch i32 %i, label %default 138 [ i32 1, label %case1 139 i32 2, label %case2 ] 140 case1: 141 %0 = mul i32 %a, 2 142 %1 = sub i32 6, 5 143 br label %bb1 144 case2: 145 %2 = and i32 %b, 3 146 %3 = sdiv i32 8, 2 147 br label %bb2 148 bb1: 149 %4 = add i32 %0, %b 150 br label %loop 151 bb2: 152 %5 = or i32 %2, %a 153 br label %loop 154 default: 155 ret void 156 } 157 )"; 158 159 Module &M = parseModule(ModuleString); 160 Function *F = M.getFunction("foo"); 161 FunctionSpecializer Specializer = getSpecializerFor(F); 162 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 163 164 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 165 166 auto FuncIter = F->begin(); 167 BasicBlock &Loop = *++FuncIter; 168 BasicBlock &Case1 = *++FuncIter; 169 BasicBlock &Case2 = *++FuncIter; 170 BasicBlock &BB1 = *++FuncIter; 171 BasicBlock &BB2 = *++FuncIter; 172 173 Instruction &Switch = Loop.front(); 174 Instruction &Mul = Case1.front(); 175 Instruction &And = Case2.front(); 176 Instruction &Sdiv = *++Case2.begin(); 177 Instruction &BrBB2 = Case2.back(); 178 Instruction &Add = BB1.front(); 179 Instruction &Or = BB2.front(); 180 Instruction &BrLoop = BB2.back(); 181 182 // mul 183 Cost Ref = getCodeSizeSavings(Mul); 184 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 185 EXPECT_EQ(Test, Ref); 186 EXPECT_TRUE(Test > 0); 187 188 // and + or + add 189 Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) + 190 getCodeSizeSavings(Add); 191 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 192 EXPECT_EQ(Test, Ref); 193 EXPECT_TRUE(Test > 0); 194 195 // switch + sdiv + br + br 196 Ref = getCodeSizeSavings(Switch) + 197 getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) + 198 getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) + 199 getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false); 200 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One); 201 EXPECT_EQ(Test, Ref); 202 EXPECT_TRUE(Test > 0); 203 204 // Latency. 205 Ref = getLatencySavings(F); 206 Test = Visitor.getLatencySavingsForKnownConstants(); 207 EXPECT_EQ(Test, Ref); 208 EXPECT_TRUE(Test > 0); 209 } 210 211 TEST_F(FunctionSpecializationTest, BranchInst) { 212 const char *ModuleString = R"( 213 define void @foo(i32 %a, i32 %b, i1 %cond) { 214 entry: 215 br label %loop 216 loop: 217 br i1 %cond, label %bb0, label %bb3 218 bb0: 219 %0 = mul i32 %a, 2 220 %1 = sub i32 6, 5 221 br i1 %cond, label %bb1, label %bb2 222 bb1: 223 %2 = add i32 %0, %b 224 %3 = sdiv i32 8, 2 225 br label %bb2 226 bb2: 227 br label %loop 228 bb3: 229 ret void 230 } 231 )"; 232 233 Module &M = parseModule(ModuleString); 234 Function *F = M.getFunction("foo"); 235 FunctionSpecializer Specializer = getSpecializerFor(F); 236 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 237 238 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 239 Constant *False = ConstantInt::getFalse(M.getContext()); 240 241 auto FuncIter = F->begin(); 242 BasicBlock &Loop = *++FuncIter; 243 BasicBlock &BB0 = *++FuncIter; 244 BasicBlock &BB1 = *++FuncIter; 245 BasicBlock &BB2 = *++FuncIter; 246 247 Instruction &Branch = Loop.front(); 248 Instruction &Mul = BB0.front(); 249 Instruction &Sub = *++BB0.begin(); 250 Instruction &BrBB1BB2 = BB0.back(); 251 Instruction &Add = BB1.front(); 252 Instruction &Sdiv = *++BB1.begin(); 253 Instruction &BrBB2 = BB1.back(); 254 Instruction &BrLoop = BB2.front(); 255 256 // mul 257 Cost Ref = getCodeSizeSavings(Mul); 258 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 259 EXPECT_EQ(Test, Ref); 260 EXPECT_TRUE(Test > 0); 261 262 // add 263 Ref = getCodeSizeSavings(Add); 264 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 265 EXPECT_EQ(Test, Ref); 266 EXPECT_TRUE(Test > 0); 267 268 // branch + sub + br + sdiv + br 269 Ref = getCodeSizeSavings(Branch) + 270 getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) + 271 getCodeSizeSavings(BrBB1BB2) + 272 getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) + 273 getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) + 274 getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false); 275 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False); 276 EXPECT_EQ(Test, Ref); 277 EXPECT_TRUE(Test > 0); 278 279 // Latency. 280 Ref = getLatencySavings(F); 281 Test = Visitor.getLatencySavingsForKnownConstants(); 282 EXPECT_EQ(Test, Ref); 283 EXPECT_TRUE(Test > 0); 284 } 285 286 TEST_F(FunctionSpecializationTest, SelectInst) { 287 const char *ModuleString = R"( 288 define i32 @foo(i1 %cond, i32 %a, i32 %b) { 289 %sel = select i1 %cond, i32 %a, i32 %b 290 ret i32 %sel 291 } 292 )"; 293 294 Module &M = parseModule(ModuleString); 295 Function *F = M.getFunction("foo"); 296 FunctionSpecializer Specializer = getSpecializerFor(F); 297 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 298 299 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 300 Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0); 301 Constant *False = ConstantInt::getFalse(M.getContext()); 302 Instruction &Select = *F->front().begin(); 303 304 Cost RefCodeSize = getCodeSizeSavings(Select); 305 Cost RefLatency = getLatencySavings(F); 306 307 Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False); 308 EXPECT_TRUE(TestCodeSize == 0); 309 TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 310 EXPECT_TRUE(TestCodeSize == 0); 311 Cost TestLatency = Visitor.getLatencySavingsForKnownConstants(); 312 EXPECT_TRUE(TestLatency == 0); 313 314 TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero); 315 EXPECT_EQ(TestCodeSize, RefCodeSize); 316 EXPECT_TRUE(TestCodeSize > 0); 317 TestLatency = Visitor.getLatencySavingsForKnownConstants(); 318 EXPECT_EQ(TestLatency, RefLatency); 319 EXPECT_TRUE(TestLatency > 0); 320 } 321 322 TEST_F(FunctionSpecializationTest, Misc) { 323 const char *ModuleString = R"( 324 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 325 @g = constant %struct_t zeroinitializer, align 16 326 327 declare i32 @llvm.smax.i32(i32, i32) 328 declare i32 @bar(i32) 329 330 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 331 %cmp = icmp eq i8 %a, 10 332 %ext = zext i1 %cmp to i64 333 %sel = select i1 %cond, i64 %ext, i64 1 334 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 335 %ld = load i32, ptr %gep 336 %fr = freeze i32 %ld 337 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 338 %call = call i32 @bar(i32 %smax) 339 %fr2 = freeze i32 %c 340 %add = add i32 %call, %fr2 341 ret i32 %add 342 } 343 )"; 344 345 Module &M = parseModule(ModuleString); 346 Function *F = M.getFunction("foo"); 347 FunctionSpecializer Specializer = getSpecializerFor(F); 348 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 349 350 GlobalVariable *GV = M.getGlobalVariable("g"); 351 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 352 Constant *True = ConstantInt::getTrue(M.getContext()); 353 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 354 355 auto BlockIter = F->front().begin(); 356 Instruction &Icmp = *BlockIter++; 357 Instruction &Zext = *BlockIter++; 358 Instruction &Select = *BlockIter++; 359 Instruction &Gep = *BlockIter++; 360 Instruction &Load = *BlockIter++; 361 Instruction &Freeze = *BlockIter++; 362 Instruction &Smax = *BlockIter++; 363 364 // icmp + zext 365 Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext); 366 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 367 EXPECT_EQ(Test, Ref); 368 EXPECT_TRUE(Test > 0); 369 370 // select 371 Ref = getCodeSizeSavings(Select); 372 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True); 373 EXPECT_EQ(Test, Ref); 374 EXPECT_TRUE(Test > 0); 375 376 // gep + load + freeze + smax 377 Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) + 378 getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax); 379 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV); 380 EXPECT_EQ(Test, Ref); 381 EXPECT_TRUE(Test > 0); 382 383 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef); 384 EXPECT_TRUE(Test == 0); 385 386 // Latency. 387 Ref = getLatencySavings(F); 388 Test = Visitor.getLatencySavingsForKnownConstants(); 389 EXPECT_EQ(Test, Ref); 390 EXPECT_TRUE(Test > 0); 391 } 392 393 TEST_F(FunctionSpecializationTest, PhiNode) { 394 const char *ModuleString = R"( 395 define void @foo(i32 %a, i32 %b, i32 %i) { 396 entry: 397 br label %loop 398 loop: 399 %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 400 switch i32 %i, label %default 401 [ i32 1, label %case1 402 i32 2, label %case2 ] 403 case1: 404 %1 = add i32 %0, 1 405 br label %bb 406 case2: 407 %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 408 br label %bb 409 bb: 410 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 411 %4 = icmp eq i32 %3, 1 412 br i1 %4, label %bb, label %loop 413 default: 414 ret void 415 } 416 )"; 417 418 Module &M = parseModule(ModuleString); 419 Function *F = M.getFunction("foo"); 420 FunctionSpecializer Specializer = getSpecializerFor(F); 421 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 422 423 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 424 425 auto FuncIter = F->begin(); 426 BasicBlock &Loop = *++FuncIter; 427 BasicBlock &Case1 = *++FuncIter; 428 BasicBlock &Case2 = *++FuncIter; 429 BasicBlock &BB = *++FuncIter; 430 431 Instruction &PhiLoop = Loop.front(); 432 Instruction &Switch = Loop.back(); 433 Instruction &Add = Case1.front(); 434 Instruction &PhiCase2 = Case2.front(); 435 Instruction &BrBB = Case2.back(); 436 Instruction &PhiBB = BB.front(); 437 Instruction &Icmp = *++BB.begin(); 438 Instruction &Branch = BB.back(); 439 440 Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 441 EXPECT_TRUE(Test == 0); 442 443 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 444 EXPECT_TRUE(Test == 0); 445 446 Test = Visitor.getLatencySavingsForKnownConstants(); 447 EXPECT_TRUE(Test == 0); 448 449 // switch + phi + br 450 Cost Ref = getCodeSizeSavings(Switch) + 451 getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) + 452 getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false); 453 Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One); 454 EXPECT_EQ(Test, Ref); 455 EXPECT_TRUE(Test > 0 && Test > 0); 456 457 // phi + phi + add + icmp + branch 458 Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) + 459 getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) + 460 getCodeSizeSavings(Branch); 461 Test = Visitor.getCodeSizeSavingsFromPendingPHIs(); 462 EXPECT_EQ(Test, Ref); 463 EXPECT_TRUE(Test > 0); 464 465 // Latency. 466 Ref = getLatencySavings(F); 467 Test = Visitor.getLatencySavingsForKnownConstants(); 468 EXPECT_EQ(Test, Ref); 469 EXPECT_TRUE(Test > 0); 470 } 471 472 TEST_F(FunctionSpecializationTest, BinOp) { 473 // Verify that we can handle binary operators even when only one operand is 474 // constant. 475 const char *ModuleString = R"( 476 define i32 @foo(i1 %a, i1 %b) { 477 %and1 = and i1 %a, %b 478 %and2 = and i1 %b, %and1 479 %sel = select i1 %and2, i32 1, i32 0 480 ret i32 %sel 481 } 482 )"; 483 484 Module &M = parseModule(ModuleString); 485 Function *F = M.getFunction("foo"); 486 FunctionSpecializer Specializer = getSpecializerFor(F); 487 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 488 489 Constant *False = ConstantInt::getFalse(M.getContext()); 490 BasicBlock &BB = F->front(); 491 Instruction &And1 = BB.front(); 492 Instruction &And2 = *++BB.begin(); 493 Instruction &Select = *++BB.begin(); 494 495 Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) + 496 getCodeSizeSavings(Select); 497 Cost RefLatency = getLatencySavings(F); 498 499 Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False); 500 Cost TestLatency = Visitor.getLatencySavingsForKnownConstants(); 501 502 EXPECT_EQ(TestCodeSize, RefCodeSize); 503 EXPECT_TRUE(TestCodeSize > 0); 504 EXPECT_EQ(TestLatency, RefLatency); 505 EXPECT_TRUE(TestLatency > 0); 506 } 507