1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 static void removeSSACopy(Function &F) { 27 for (BasicBlock &BB : F) { 28 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 29 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 30 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 31 continue; 32 Inst.replaceAllUsesWith(II->getOperand(0)); 33 Inst.eraseFromParent(); 34 } 35 } 36 } 37 } 38 39 class FunctionSpecializationTest : public testing::Test { 40 protected: 41 LLVMContext Ctx; 42 FunctionAnalysisManager FAM; 43 std::unique_ptr<Module> M; 44 std::unique_ptr<SCCPSolver> Solver; 45 46 FunctionSpecializationTest() { 47 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 48 FAM.registerPass([&] { return TargetIRAnalysis(); }); 49 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 50 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 51 FAM.registerPass([&] { return LoopAnalysis(); }); 52 FAM.registerPass([&] { return AssumptionAnalysis(); }); 53 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 54 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 55 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 56 } 57 58 Module &parseModule(const char *ModuleString) { 59 SMDiagnostic Err; 60 M = parseAssemblyString(ModuleString, Err, Ctx); 61 EXPECT_TRUE(M); 62 return *M; 63 } 64 65 FunctionSpecializer getSpecializerFor(Function *F) { 66 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 67 return FAM.getResult<TargetLibraryAnalysis>(F); 68 }; 69 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 70 return FAM.getResult<TargetIRAnalysis>(F); 71 }; 72 auto GetAC = [this](Function &F) -> AssumptionCache & { 73 return FAM.getResult<AssumptionAnalysis>(F); 74 }; 75 auto GetDT = [this](Function &F) -> DominatorTree & { 76 return FAM.getResult<DominatorTreeAnalysis>(F); 77 }; 78 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 79 return FAM.getResult<BlockFrequencyAnalysis>(F); 80 }; 81 82 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 83 84 DominatorTree &DT = GetDT(*F); 85 AssumptionCache &AC = GetAC(*F); 86 Solver->addPredicateInfo(*F, DT, AC); 87 88 Solver->markBlockExecutable(&F->front()); 89 for (Argument &Arg : F->args()) 90 Solver->markOverdefined(&Arg); 91 Solver->solveWhileResolvedUndefsIn(*M); 92 93 removeSSACopy(*F); 94 95 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 96 GetAC); 97 } 98 99 Bonus getInstCost(Instruction &I, bool SizeOnly = false) { 100 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 102 103 Cost CodeSize = 104 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 105 106 Cost Latency = SizeOnly ? 0 : 107 BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 108 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency); 109 110 return {CodeSize, Latency}; 111 } 112 }; 113 114 } // namespace llvm 115 116 using namespace llvm; 117 118 TEST_F(FunctionSpecializationTest, SwitchInst) { 119 const char *ModuleString = R"( 120 define void @foo(i32 %a, i32 %b, i32 %i) { 121 entry: 122 br label %loop 123 loop: 124 switch i32 %i, label %default 125 [ i32 1, label %case1 126 i32 2, label %case2 ] 127 case1: 128 %0 = mul i32 %a, 2 129 %1 = sub i32 6, 5 130 br label %bb1 131 case2: 132 %2 = and i32 %b, 3 133 %3 = sdiv i32 8, 2 134 br label %bb2 135 bb1: 136 %4 = add i32 %0, %b 137 br label %loop 138 bb2: 139 %5 = or i32 %2, %a 140 br label %loop 141 default: 142 ret void 143 } 144 )"; 145 146 Module &M = parseModule(ModuleString); 147 Function *F = M.getFunction("foo"); 148 FunctionSpecializer Specializer = getSpecializerFor(F); 149 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 150 151 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 152 153 auto FuncIter = F->begin(); 154 BasicBlock &Loop = *++FuncIter; 155 BasicBlock &Case1 = *++FuncIter; 156 BasicBlock &Case2 = *++FuncIter; 157 BasicBlock &BB1 = *++FuncIter; 158 BasicBlock &BB2 = *++FuncIter; 159 160 Instruction &Switch = Loop.front(); 161 Instruction &Mul = Case1.front(); 162 Instruction &And = Case2.front(); 163 Instruction &Sdiv = *++Case2.begin(); 164 Instruction &BrBB2 = Case2.back(); 165 Instruction &Add = BB1.front(); 166 Instruction &Or = BB2.front(); 167 Instruction &BrLoop = BB2.back(); 168 169 // mul 170 Bonus Ref = getInstCost(Mul); 171 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 172 EXPECT_EQ(Test, Ref); 173 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 174 175 // and + or + add 176 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 177 Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 178 EXPECT_EQ(Test, Ref); 179 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 180 181 // switch + sdiv + br + br 182 Ref = getInstCost(Switch) + 183 getInstCost(Sdiv, /*SizeOnly =*/ true) + 184 getInstCost(BrBB2, /*SizeOnly =*/ true) + 185 getInstCost(BrLoop, /*SizeOnly =*/ true); 186 Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 187 EXPECT_EQ(Test, Ref); 188 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 189 } 190 191 TEST_F(FunctionSpecializationTest, BranchInst) { 192 const char *ModuleString = R"( 193 define void @foo(i32 %a, i32 %b, i1 %cond) { 194 entry: 195 br label %loop 196 loop: 197 br i1 %cond, label %bb0, label %bb2 198 bb0: 199 %0 = mul i32 %a, 2 200 %1 = sub i32 6, 5 201 br label %bb1 202 bb1: 203 %2 = add i32 %0, %b 204 %3 = sdiv i32 8, 2 205 br label %loop 206 bb2: 207 ret void 208 } 209 )"; 210 211 Module &M = parseModule(ModuleString); 212 Function *F = M.getFunction("foo"); 213 FunctionSpecializer Specializer = getSpecializerFor(F); 214 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 215 216 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 217 Constant *False = ConstantInt::getFalse(M.getContext()); 218 219 auto FuncIter = F->begin(); 220 BasicBlock &Loop = *++FuncIter; 221 BasicBlock &BB0 = *++FuncIter; 222 BasicBlock &BB1 = *++FuncIter; 223 224 Instruction &Branch = Loop.front(); 225 Instruction &Mul = BB0.front(); 226 Instruction &Sub = *++BB0.begin(); 227 Instruction &BrBB1 = BB0.back(); 228 Instruction &Add = BB1.front(); 229 Instruction &Sdiv = *++BB1.begin(); 230 Instruction &BrLoop = BB1.back(); 231 232 // mul 233 Bonus Ref = getInstCost(Mul); 234 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 235 EXPECT_EQ(Test, Ref); 236 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 237 238 // add 239 Ref = getInstCost(Add); 240 Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 241 EXPECT_EQ(Test, Ref); 242 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 243 244 // branch + sub + br + sdiv + br 245 Ref = getInstCost(Branch) + 246 getInstCost(Sub, /*SizeOnly =*/ true) + 247 getInstCost(BrBB1, /*SizeOnly =*/ true) + 248 getInstCost(Sdiv, /*SizeOnly =*/ true) + 249 getInstCost(BrLoop, /*SizeOnly =*/ true); 250 Test = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 251 EXPECT_EQ(Test, Ref); 252 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 253 } 254 255 TEST_F(FunctionSpecializationTest, Misc) { 256 const char *ModuleString = R"( 257 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 258 @g = constant %struct_t zeroinitializer, align 16 259 260 declare i32 @llvm.smax.i32(i32, i32) 261 declare i32 @bar(i32) 262 263 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 264 %cmp = icmp eq i8 %a, 10 265 %ext = zext i1 %cmp to i64 266 %sel = select i1 %cond, i64 %ext, i64 1 267 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 268 %ld = load i32, ptr %gep 269 %fr = freeze i32 %ld 270 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 271 %call = call i32 @bar(i32 %smax) 272 %fr2 = freeze i32 %c 273 %add = add i32 %call, %fr2 274 ret i32 %add 275 } 276 )"; 277 278 Module &M = parseModule(ModuleString); 279 Function *F = M.getFunction("foo"); 280 FunctionSpecializer Specializer = getSpecializerFor(F); 281 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 282 283 GlobalVariable *GV = M.getGlobalVariable("g"); 284 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 285 Constant *True = ConstantInt::getTrue(M.getContext()); 286 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 287 288 auto BlockIter = F->front().begin(); 289 Instruction &Icmp = *BlockIter++; 290 Instruction &Zext = *BlockIter++; 291 Instruction &Select = *BlockIter++; 292 Instruction &Gep = *BlockIter++; 293 Instruction &Load = *BlockIter++; 294 Instruction &Freeze = *BlockIter++; 295 Instruction &Smax = *BlockIter++; 296 297 // icmp + zext 298 Bonus Ref = getInstCost(Icmp) + getInstCost(Zext); 299 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 300 EXPECT_EQ(Test, Ref); 301 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 302 303 // select 304 Ref = getInstCost(Select); 305 Test = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 306 EXPECT_EQ(Test, Ref); 307 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 308 309 // gep + load + freeze + smax 310 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 311 getInstCost(Smax); 312 Test = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 313 EXPECT_EQ(Test, Ref); 314 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 315 316 Test = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor); 317 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 318 } 319 320 TEST_F(FunctionSpecializationTest, PhiNode) { 321 const char *ModuleString = R"( 322 define void @foo(i32 %a, i32 %b, i32 %i) { 323 entry: 324 br label %loop 325 loop: 326 %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 327 switch i32 %i, label %default 328 [ i32 1, label %case1 329 i32 2, label %case2 ] 330 case1: 331 %1 = add i32 %0, 1 332 br label %bb 333 case2: 334 %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 335 br label %bb 336 bb: 337 %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 338 %4 = icmp eq i32 %3, 1 339 br i1 %4, label %bb, label %loop 340 default: 341 ret void 342 } 343 )"; 344 345 Module &M = parseModule(ModuleString); 346 Function *F = M.getFunction("foo"); 347 FunctionSpecializer Specializer = getSpecializerFor(F); 348 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 349 350 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 351 352 auto FuncIter = F->begin(); 353 BasicBlock &Loop = *++FuncIter; 354 BasicBlock &Case1 = *++FuncIter; 355 BasicBlock &Case2 = *++FuncIter; 356 BasicBlock &BB = *++FuncIter; 357 358 Instruction &PhiLoop = Loop.front(); 359 Instruction &Switch = Loop.back(); 360 Instruction &Add = Case1.front(); 361 Instruction &PhiCase2 = Case2.front(); 362 Instruction &BrBB = Case2.back(); 363 Instruction &PhiBB = BB.front(); 364 Instruction &Icmp = *++BB.begin(); 365 Instruction &Branch = BB.back(); 366 367 Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 368 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 369 370 Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 371 EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0); 372 373 // switch + phi + br 374 Bonus Ref = getInstCost(Switch) + 375 getInstCost(PhiCase2, /*SizeOnly =*/ true) + 376 getInstCost(BrBB, /*SizeOnly =*/ true); 377 Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 378 EXPECT_EQ(Test, Ref); 379 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 380 381 // phi + phi + add + icmp + branch 382 Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) + 383 getInstCost(Icmp) + getInstCost(Branch); 384 Test = Visitor.getBonusFromPendingPHIs(); 385 EXPECT_EQ(Test, Ref); 386 EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0); 387 } 388 389