1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/IR/PassInstrumentation.h" 19 #include "llvm/Support/SourceMgr.h" 20 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 21 #include "llvm/Transforms/Utils/SCCPSolver.h" 22 #include "gtest/gtest.h" 23 #include <memory> 24 25 namespace llvm { 26 27 static void removeSSACopy(Function &F) { 28 for (BasicBlock &BB : F) { 29 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 30 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 31 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 32 continue; 33 Inst.replaceAllUsesWith(II->getOperand(0)); 34 Inst.eraseFromParent(); 35 } 36 } 37 } 38 } 39 40 class FunctionSpecializationTest : public testing::Test { 41 protected: 42 LLVMContext Ctx; 43 FunctionAnalysisManager FAM; 44 std::unique_ptr<Module> M; 45 std::unique_ptr<SCCPSolver> Solver; 46 47 FunctionSpecializationTest() { 48 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 49 FAM.registerPass([&] { return TargetIRAnalysis(); }); 50 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 51 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 52 FAM.registerPass([&] { return LoopAnalysis(); }); 53 FAM.registerPass([&] { return AssumptionAnalysis(); }); 54 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 55 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 56 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 57 } 58 59 Module &parseModule(const char *ModuleString) { 60 SMDiagnostic Err; 61 M = parseAssemblyString(ModuleString, Err, Ctx); 62 EXPECT_TRUE(M); 63 return *M; 64 } 65 66 FunctionSpecializer getSpecializerFor(Function *F) { 67 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 68 return FAM.getResult<TargetLibraryAnalysis>(F); 69 }; 70 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 71 return FAM.getResult<TargetIRAnalysis>(F); 72 }; 73 auto GetAC = [this](Function &F) -> AssumptionCache & { 74 return FAM.getResult<AssumptionAnalysis>(F); 75 }; 76 auto GetDT = [this](Function &F) -> DominatorTree & { 77 return FAM.getResult<DominatorTreeAnalysis>(F); 78 }; 79 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 80 return FAM.getResult<BlockFrequencyAnalysis>(F); 81 }; 82 83 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 84 85 DominatorTree &DT = GetDT(*F); 86 AssumptionCache &AC = GetAC(*F); 87 Solver->addPredicateInfo(*F, DT, AC); 88 89 Solver->markBlockExecutable(&F->front()); 90 for (Argument &Arg : F->args()) 91 Solver->markOverdefined(&Arg); 92 Solver->solveWhileResolvedUndefsIn(*M); 93 94 removeSSACopy(*F); 95 96 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 97 GetAC); 98 } 99 100 Bonus getInstCost(Instruction &I, bool SizeOnly = false) { 101 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 102 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 103 104 Cost CodeSize = 105 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 106 107 Cost Latency = 108 SizeOnly 109 ? 0 110 : BFI.getBlockFreq(I.getParent()).getFrequency() / 111 BFI.getEntryFreq().getFrequency() * 112 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency); 113 114 return {CodeSize, Latency}; 115 } 116 }; 117 118 } // namespace llvm 119 120 using namespace llvm; 121 122 TEST_F(FunctionSpecializationTest, SwitchInst) { 123 const char *ModuleString = R"( 124 define void @foo(i32 %a, i32 %b, i32 %i) { 125 entry: 126 br label %loop 127 loop: 128 switch i32 %i, label %default 129 [ i32 1, label %case1 130 i32 2, label %case2 ] 131 case1: 132 %0 = mul i32 %a, 2 133 %1 = sub i32 6, 5 134 br label %bb1 135 case2: 136 %2 = and i32 %b, 3 137 %3 = sdiv i32 8, 2 138 br label %bb2 139 bb1: 140 %4 = add i32 %0, %b 141 br label %loop 142 bb2: 143 %5 = or i32 %2, %a 144 br label %loop 145 default: 146 ret void 147 } 148 )"; 149 150 Module &M = parseModule(ModuleString); 151 Function *F = M.getFunction("foo"); 152 FunctionSpecializer Specializer = getSpecializerFor(F); 153 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 154 155 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 156 157 auto FuncIter = F->begin(); 158 BasicBlock &Loop = *++FuncIter; 159 BasicBlock &Case1 = *++FuncIter; 160 BasicBlock &Case2 = *++FuncIter; 161 BasicBlock &BB1 = *++FuncIter; 162 BasicBlock &BB2 = *++FuncIter; 163 164 Instruction &Switch = Loop.front(); 165 Instruction &Mul = Case1.front(); 166 Instruction &And = Case2.front(); 167 Instruction &Sdiv = *++Case2.begin(); 168 Instruction &BrBB2 = Case2.back(); 169 Instruction &Add = BB1.front(); 170 Instruction &Or = BB2.front(); 171 Instruction &BrLoop = BB2.back(); 172 173 // mul 174 Bonus Ref = getInstCost(Mul); 175 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 176 EXPECT_EQ(Test, Ref); 177 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 178 179 // and + or + add 180 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 181 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 182 EXPECT_EQ(Test, Ref); 183 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 184 185 // switch + sdiv + br + br 186 Ref = getInstCost(Switch) + 187 getInstCost(Sdiv, /*SizeOnly =*/ true) + 188 getInstCost(BrBB2, /*SizeOnly =*/ true) + 189 getInstCost(BrLoop, /*SizeOnly =*/ true); 190 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 191 EXPECT_EQ(Test, Ref); 192 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 193 } 194 195 TEST_F(FunctionSpecializationTest, BranchInst) { 196 const char *ModuleString = R"( 197 define void @foo(i32 %a, i32 %b, i1 %cond) { 198 entry: 199 br label %loop 200 loop: 201 br i1 %cond, label %bb0, label %bb3 202 bb0: 203 %0 = mul i32 %a, 2 204 %1 = sub i32 6, 5 205 br i1 %cond, label %bb1, label %bb2 206 bb1: 207 %2 = add i32 %0, %b 208 %3 = sdiv i32 8, 2 209 br label %bb2 210 bb2: 211 br label %loop 212 bb3: 213 ret void 214 } 215 )"; 216 217 Module &M = parseModule(ModuleString); 218 Function *F = M.getFunction("foo"); 219 FunctionSpecializer Specializer = getSpecializerFor(F); 220 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 221 222 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 223 Constant *False = ConstantInt::getFalse(M.getContext()); 224 225 auto FuncIter = F->begin(); 226 BasicBlock &Loop = *++FuncIter; 227 BasicBlock &BB0 = *++FuncIter; 228 BasicBlock &BB1 = *++FuncIter; 229 BasicBlock &BB2 = *++FuncIter; 230 231 Instruction &Branch = Loop.front(); 232 Instruction &Mul = BB0.front(); 233 Instruction &Sub = *++BB0.begin(); 234 Instruction &BrBB1BB2 = BB0.back(); 235 Instruction &Add = BB1.front(); 236 Instruction &Sdiv = *++BB1.begin(); 237 Instruction &BrBB2 = BB1.back(); 238 Instruction &BrLoop = BB2.front(); 239 240 // mul 241 Bonus Ref = getInstCost(Mul); 242 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 243 EXPECT_EQ(Test, Ref); 244 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 245 246 // add 247 Ref = getInstCost(Add); 248 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 249 EXPECT_EQ(Test, Ref); 250 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 251 252 // branch + sub + br + sdiv + br 253 Ref = getInstCost(Branch) + 254 getInstCost(Sub, /*SizeOnly =*/ true) + 255 getInstCost(BrBB1BB2) + 256 getInstCost(Sdiv, /*SizeOnly =*/ true) + 257 getInstCost(BrBB2, /*SizeOnly =*/ true) + 258 getInstCost(BrLoop, /*SizeOnly =*/ true); 259 Test = Visitor.getSpecializationBonus(F->getArg(2), False); 260 EXPECT_EQ(Test, Ref); 261 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 262 } 263 264 TEST_F(FunctionSpecializationTest, Misc) { 265 const char *ModuleString = R"( 266 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 267 @g = constant %struct_t zeroinitializer, align 16 268 269 declare i32 @llvm.smax.i32(i32, i32) 270 declare i32 @bar(i32) 271 272 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 273 %cmp = icmp eq i8 %a, 10 274 %ext = zext i1 %cmp to i64 275 %sel = select i1 %cond, i64 %ext, i64 1 276 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 277 %ld = load i32, ptr %gep 278 %fr = freeze i32 %ld 279 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 280 %call = call i32 @bar(i32 %smax) 281 %fr2 = freeze i32 %c 282 %add = add i32 %call, %fr2 283 ret i32 %add 284 } 285 )"; 286 287 Module &M = parseModule(ModuleString); 288 Function *F = M.getFunction("foo"); 289 FunctionSpecializer Specializer = getSpecializerFor(F); 290 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 291 292 GlobalVariable *GV = M.getGlobalVariable("g"); 293 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 294 Constant *True = ConstantInt::getTrue(M.getContext()); 295 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 296 297 auto BlockIter = F->front().begin(); 298 Instruction &Icmp = *BlockIter++; 299 Instruction &Zext = *BlockIter++; 300 Instruction &Select = *BlockIter++; 301 Instruction &Gep = *BlockIter++; 302 Instruction &Load = *BlockIter++; 303 Instruction &Freeze = *BlockIter++; 304 Instruction &Smax = *BlockIter++; 305 306 // icmp + zext 307 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext); 308 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 309 EXPECT_EQ(Test, Ref); 310 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 311 312 // select 313 Ref = getInstCost(Select); 314 Test = Visitor.getSpecializationBonus(F->getArg(1), True); 315 EXPECT_EQ(Test, Ref); 316 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 317 318 // gep + load + freeze + smax 319 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 320 getInstCost(Smax); 321 Test = Visitor.getSpecializationBonus(F->getArg(2), GV); 322 EXPECT_EQ(Test, Ref); 323 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 324 325 Test = Visitor.getSpecializationBonus(F->getArg(3), Undef); 326 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 327 } 328 329 TEST_F(FunctionSpecializationTest, PhiNode) { 330 const char *ModuleString = R"( 331 define void @foo(i32 %a, i32 %b, i32 %i) { 332 entry: 333 br label %loop 334 loop: 335 %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 336 switch i32 %i, label %default 337 [ i32 1, label %case1 338 i32 2, label %case2 ] 339 case1: 340 %1 = add i32 %0, 1 341 br label %bb 342 case2: 343 %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 344 br label %bb 345 bb: 346 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 347 %4 = icmp eq i32 %3, 1 348 br i1 %4, label %bb, label %loop 349 default: 350 ret void 351 } 352 )"; 353 354 Module &M = parseModule(ModuleString); 355 Function *F = M.getFunction("foo"); 356 FunctionSpecializer Specializer = getSpecializerFor(F); 357 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 358 359 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 360 361 auto FuncIter = F->begin(); 362 BasicBlock &Loop = *++FuncIter; 363 BasicBlock &Case1 = *++FuncIter; 364 BasicBlock &Case2 = *++FuncIter; 365 BasicBlock &BB = *++FuncIter; 366 367 Instruction &PhiLoop = Loop.front(); 368 Instruction &Switch = Loop.back(); 369 Instruction &Add = Case1.front(); 370 Instruction &PhiCase2 = Case2.front(); 371 Instruction &BrBB = Case2.back(); 372 Instruction &PhiBB = BB.front(); 373 Instruction &Icmp = *++BB.begin(); 374 Instruction &Branch = BB.back(); 375 376 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 377 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 378 379 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 380 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 381 382 // switch + phi + br 383 Bonus Ref = getInstCost(Switch) + 384 getInstCost(PhiCase2, /*SizeOnly =*/ true) + 385 getInstCost(BrBB, /*SizeOnly =*/ true); 386 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 387 EXPECT_EQ(Test, Ref); 388 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 389 390 // phi + phi + add + icmp + branch 391 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) + 392 getInstCost(Icmp) + getInstCost(Branch); 393 Test = Visitor.getBonusFromPendingPHIs(); 394 EXPECT_EQ(Test, Ref); 395 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 396 } 397 398