1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 static void removeSSACopy(Function &F) { 27 for (BasicBlock &BB : F) { 28 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 29 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 30 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 31 continue; 32 Inst.replaceAllUsesWith(II->getOperand(0)); 33 Inst.eraseFromParent(); 34 } 35 } 36 } 37 } 38 39 class FunctionSpecializationTest : public testing::Test { 40 protected: 41 LLVMContext Ctx; 42 FunctionAnalysisManager FAM; 43 std::unique_ptr<Module> M; 44 std::unique_ptr<SCCPSolver> Solver; 45 46 FunctionSpecializationTest() { 47 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 48 FAM.registerPass([&] { return TargetIRAnalysis(); }); 49 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 50 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 51 FAM.registerPass([&] { return LoopAnalysis(); }); 52 FAM.registerPass([&] { return AssumptionAnalysis(); }); 53 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 54 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 55 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 56 } 57 58 Module &parseModule(const char *ModuleString) { 59 SMDiagnostic Err; 60 M = parseAssemblyString(ModuleString, Err, Ctx); 61 EXPECT_TRUE(M); 62 return *M; 63 } 64 65 FunctionSpecializer getSpecializerFor(Function *F) { 66 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 67 return FAM.getResult<TargetLibraryAnalysis>(F); 68 }; 69 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 70 return FAM.getResult<TargetIRAnalysis>(F); 71 }; 72 auto GetAC = [this](Function &F) -> AssumptionCache & { 73 return FAM.getResult<AssumptionAnalysis>(F); 74 }; 75 auto GetDT = [this](Function &F) -> DominatorTree & { 76 return FAM.getResult<DominatorTreeAnalysis>(F); 77 }; 78 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 79 return FAM.getResult<BlockFrequencyAnalysis>(F); 80 }; 81 82 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 83 84 DominatorTree &DT = GetDT(*F); 85 AssumptionCache &AC = GetAC(*F); 86 Solver->addPredicateInfo(*F, DT, AC); 87 88 Solver->markBlockExecutable(&F->front()); 89 for (Argument &Arg : F->args()) 90 Solver->markOverdefined(&Arg); 91 Solver->solveWhileResolvedUndefsIn(*M); 92 93 removeSSACopy(*F); 94 95 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 96 GetAC); 97 } 98 99 Bonus getInstCost(Instruction &I, bool SizeOnly = false) { 100 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 102 103 Cost CodeSize = 104 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 105 106 Cost Latency = SizeOnly ? 0 : 107 BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 108 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency); 109 110 return {CodeSize, Latency}; 111 } 112 }; 113 114 } // namespace llvm 115 116 using namespace llvm; 117 118 TEST_F(FunctionSpecializationTest, SwitchInst) { 119 const char *ModuleString = R"( 120 define void @foo(i32 %a, i32 %b, i32 %i) { 121 entry: 122 br label %loop 123 loop: 124 switch i32 %i, label %default 125 [ i32 1, label %case1 126 i32 2, label %case2 ] 127 case1: 128 %0 = mul i32 %a, 2 129 %1 = sub i32 6, 5 130 br label %bb1 131 case2: 132 %2 = and i32 %b, 3 133 %3 = sdiv i32 8, 2 134 br label %bb2 135 bb1: 136 %4 = add i32 %0, %b 137 br label %loop 138 bb2: 139 %5 = or i32 %2, %a 140 br label %loop 141 default: 142 ret void 143 } 144 )"; 145 146 Module &M = parseModule(ModuleString); 147 Function *F = M.getFunction("foo"); 148 FunctionSpecializer Specializer = getSpecializerFor(F); 149 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 150 151 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 152 153 auto FuncIter = F->begin(); 154 BasicBlock &Loop = *++FuncIter; 155 BasicBlock &Case1 = *++FuncIter; 156 BasicBlock &Case2 = *++FuncIter; 157 BasicBlock &BB1 = *++FuncIter; 158 BasicBlock &BB2 = *++FuncIter; 159 160 Instruction &Switch = Loop.front(); 161 Instruction &Mul = Case1.front(); 162 Instruction &And = Case2.front(); 163 Instruction &Sdiv = *++Case2.begin(); 164 Instruction &BrBB2 = Case2.back(); 165 Instruction &Add = BB1.front(); 166 Instruction &Or = BB2.front(); 167 Instruction &BrLoop = BB2.back(); 168 169 // mul 170 Bonus Ref = getInstCost(Mul); 171 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 172 EXPECT_EQ(Test, Ref); 173 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 174 175 // and + or + add 176 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 177 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 178 EXPECT_EQ(Test, Ref); 179 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 180 181 // switch + sdiv + br + br 182 Ref = getInstCost(Switch) + 183 getInstCost(Sdiv, /*SizeOnly =*/ true) + 184 getInstCost(BrBB2, /*SizeOnly =*/ true) + 185 getInstCost(BrLoop, /*SizeOnly =*/ true); 186 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 187 EXPECT_EQ(Test, Ref); 188 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 189 } 190 191 TEST_F(FunctionSpecializationTest, BranchInst) { 192 const char *ModuleString = R"( 193 define void @foo(i32 %a, i32 %b, i1 %cond) { 194 entry: 195 br label %loop 196 loop: 197 br i1 %cond, label %bb0, label %bb3 198 bb0: 199 %0 = mul i32 %a, 2 200 %1 = sub i32 6, 5 201 br i1 %cond, label %bb1, label %bb2 202 bb1: 203 %2 = add i32 %0, %b 204 %3 = sdiv i32 8, 2 205 br label %bb2 206 bb2: 207 br label %loop 208 bb3: 209 ret void 210 } 211 )"; 212 213 Module &M = parseModule(ModuleString); 214 Function *F = M.getFunction("foo"); 215 FunctionSpecializer Specializer = getSpecializerFor(F); 216 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 217 218 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 219 Constant *False = ConstantInt::getFalse(M.getContext()); 220 221 auto FuncIter = F->begin(); 222 BasicBlock &Loop = *++FuncIter; 223 BasicBlock &BB0 = *++FuncIter; 224 BasicBlock &BB1 = *++FuncIter; 225 BasicBlock &BB2 = *++FuncIter; 226 227 Instruction &Branch = Loop.front(); 228 Instruction &Mul = BB0.front(); 229 Instruction &Sub = *++BB0.begin(); 230 Instruction &BrBB1BB2 = BB0.back(); 231 Instruction &Add = BB1.front(); 232 Instruction &Sdiv = *++BB1.begin(); 233 Instruction &BrBB2 = BB1.back(); 234 Instruction &BrLoop = BB2.front(); 235 236 // mul 237 Bonus Ref = getInstCost(Mul); 238 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 239 EXPECT_EQ(Test, Ref); 240 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 241 242 // add 243 Ref = getInstCost(Add); 244 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 245 EXPECT_EQ(Test, Ref); 246 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 247 248 // branch + sub + br + sdiv + br 249 Ref = getInstCost(Branch) + 250 getInstCost(Sub, /*SizeOnly =*/ true) + 251 getInstCost(BrBB1BB2) + 252 getInstCost(Sdiv, /*SizeOnly =*/ true) + 253 getInstCost(BrBB2, /*SizeOnly =*/ true) + 254 getInstCost(BrLoop, /*SizeOnly =*/ true); 255 Test = Visitor.getSpecializationBonus(F->getArg(2), False); 256 EXPECT_EQ(Test, Ref); 257 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 258 } 259 260 TEST_F(FunctionSpecializationTest, Misc) { 261 const char *ModuleString = R"( 262 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 263 @g = constant %struct_t zeroinitializer, align 16 264 265 declare i32 @llvm.smax.i32(i32, i32) 266 declare i32 @bar(i32) 267 268 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 269 %cmp = icmp eq i8 %a, 10 270 %ext = zext i1 %cmp to i64 271 %sel = select i1 %cond, i64 %ext, i64 1 272 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 273 %ld = load i32, ptr %gep 274 %fr = freeze i32 %ld 275 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 276 %call = call i32 @bar(i32 %smax) 277 %fr2 = freeze i32 %c 278 %add = add i32 %call, %fr2 279 ret i32 %add 280 } 281 )"; 282 283 Module &M = parseModule(ModuleString); 284 Function *F = M.getFunction("foo"); 285 FunctionSpecializer Specializer = getSpecializerFor(F); 286 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 287 288 GlobalVariable *GV = M.getGlobalVariable("g"); 289 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 290 Constant *True = ConstantInt::getTrue(M.getContext()); 291 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 292 293 auto BlockIter = F->front().begin(); 294 Instruction &Icmp = *BlockIter++; 295 Instruction &Zext = *BlockIter++; 296 Instruction &Select = *BlockIter++; 297 Instruction &Gep = *BlockIter++; 298 Instruction &Load = *BlockIter++; 299 Instruction &Freeze = *BlockIter++; 300 Instruction &Smax = *BlockIter++; 301 302 // icmp + zext 303 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext); 304 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 305 EXPECT_EQ(Test, Ref); 306 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 307 308 // select 309 Ref = getInstCost(Select); 310 Test = Visitor.getSpecializationBonus(F->getArg(1), True); 311 EXPECT_EQ(Test, Ref); 312 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 313 314 // gep + load + freeze + smax 315 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 316 getInstCost(Smax); 317 Test = Visitor.getSpecializationBonus(F->getArg(2), GV); 318 EXPECT_EQ(Test, Ref); 319 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 320 321 Test = Visitor.getSpecializationBonus(F->getArg(3), Undef); 322 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 323 } 324 325 TEST_F(FunctionSpecializationTest, PhiNode) { 326 const char *ModuleString = R"( 327 define void @foo(i32 %a, i32 %b, i32 %i) { 328 entry: 329 br label %loop 330 loop: 331 %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 332 switch i32 %i, label %default 333 [ i32 1, label %case1 334 i32 2, label %case2 ] 335 case1: 336 %1 = add i32 %0, 1 337 br label %bb 338 case2: 339 %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 340 br label %bb 341 bb: 342 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 343 %4 = icmp eq i32 %3, 1 344 br i1 %4, label %bb, label %loop 345 default: 346 ret void 347 } 348 )"; 349 350 Module &M = parseModule(ModuleString); 351 Function *F = M.getFunction("foo"); 352 FunctionSpecializer Specializer = getSpecializerFor(F); 353 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 354 355 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 356 357 auto FuncIter = F->begin(); 358 BasicBlock &Loop = *++FuncIter; 359 BasicBlock &Case1 = *++FuncIter; 360 BasicBlock &Case2 = *++FuncIter; 361 BasicBlock &BB = *++FuncIter; 362 363 Instruction &PhiLoop = Loop.front(); 364 Instruction &Switch = Loop.back(); 365 Instruction &Add = Case1.front(); 366 Instruction &PhiCase2 = Case2.front(); 367 Instruction &BrBB = Case2.back(); 368 Instruction &PhiBB = BB.front(); 369 Instruction &Icmp = *++BB.begin(); 370 Instruction &Branch = BB.back(); 371 372 Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One); 373 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 374 375 Test = Visitor.getSpecializationBonus(F->getArg(1), One); 376 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 377 378 // switch + phi + br 379 Bonus Ref = getInstCost(Switch) + 380 getInstCost(PhiCase2, /*SizeOnly =*/ true) + 381 getInstCost(BrBB, /*SizeOnly =*/ true); 382 Test = Visitor.getSpecializationBonus(F->getArg(2), One); 383 EXPECT_EQ(Test, Ref); 384 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 385 386 // phi + phi + add + icmp + branch 387 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) + 388 getInstCost(Icmp) + getInstCost(Branch); 389 Test = Visitor.getBonusFromPendingPHIs(); 390 EXPECT_EQ(Test, Ref); 391 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 392 } 393 394