1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 static void removeSSACopy(Function &F) { 27 for (BasicBlock &BB : F) { 28 for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 29 if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 30 if (II->getIntrinsicID() != Intrinsic::ssa_copy) 31 continue; 32 Inst.replaceAllUsesWith(II->getOperand(0)); 33 Inst.eraseFromParent(); 34 } 35 } 36 } 37 } 38 39 class FunctionSpecializationTest : public testing::Test { 40 protected: 41 LLVMContext Ctx; 42 FunctionAnalysisManager FAM; 43 std::unique_ptr<Module> M; 44 std::unique_ptr<SCCPSolver> Solver; 45 46 FunctionSpecializationTest() { 47 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 48 FAM.registerPass([&] { return TargetIRAnalysis(); }); 49 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 50 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 51 FAM.registerPass([&] { return LoopAnalysis(); }); 52 FAM.registerPass([&] { return AssumptionAnalysis(); }); 53 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 54 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 55 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 56 } 57 58 Module &parseModule(const char *ModuleString) { 59 SMDiagnostic Err; 60 M = parseAssemblyString(ModuleString, Err, Ctx); 61 EXPECT_TRUE(M); 62 return *M; 63 } 64 65 FunctionSpecializer getSpecializerFor(Function *F) { 66 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 67 return FAM.getResult<TargetLibraryAnalysis>(F); 68 }; 69 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 70 return FAM.getResult<TargetIRAnalysis>(F); 71 }; 72 auto GetAC = [this](Function &F) -> AssumptionCache & { 73 return FAM.getResult<AssumptionAnalysis>(F); 74 }; 75 auto GetDT = [this](Function &F) -> DominatorTree & { 76 return FAM.getResult<DominatorTreeAnalysis>(F); 77 }; 78 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 79 return FAM.getResult<BlockFrequencyAnalysis>(F); 80 }; 81 82 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 83 84 DominatorTree &DT = GetDT(*F); 85 AssumptionCache &AC = GetAC(*F); 86 Solver->addPredicateInfo(*F, DT, AC); 87 88 Solver->markBlockExecutable(&F->front()); 89 for (Argument &Arg : F->args()) 90 Solver->markOverdefined(&Arg); 91 Solver->solveWhileResolvedUndefsIn(*M); 92 93 removeSSACopy(*F); 94 95 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 96 GetAC); 97 } 98 99 Cost getInstCost(Instruction &I) { 100 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 101 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 102 103 return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 104 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); 105 } 106 }; 107 108 } // namespace llvm 109 110 using namespace llvm; 111 112 TEST_F(FunctionSpecializationTest, SwitchInst) { 113 const char *ModuleString = R"( 114 define void @foo(i32 %a, i32 %b, i32 %i) { 115 entry: 116 br label %loop 117 loop: 118 switch i32 %i, label %default 119 [ i32 1, label %case1 120 i32 2, label %case2 ] 121 case1: 122 %0 = mul i32 %a, 2 123 %1 = sub i32 6, 5 124 br label %bb1 125 case2: 126 %2 = and i32 %b, 3 127 %3 = sdiv i32 8, 2 128 br label %bb2 129 bb1: 130 %4 = add i32 %0, %b 131 br label %loop 132 bb2: 133 %5 = or i32 %2, %a 134 br label %loop 135 default: 136 ret void 137 } 138 )"; 139 140 Module &M = parseModule(ModuleString); 141 Function *F = M.getFunction("foo"); 142 FunctionSpecializer Specializer = getSpecializerFor(F); 143 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 144 145 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 146 147 auto FuncIter = F->begin(); 148 ++FuncIter; 149 BasicBlock &Case1 = *++FuncIter; 150 BasicBlock &Case2 = *++FuncIter; 151 BasicBlock &BB1 = *++FuncIter; 152 BasicBlock &BB2 = *++FuncIter; 153 154 Instruction &Mul = Case1.front(); 155 Instruction &And = Case2.front(); 156 Instruction &Sdiv = *++Case2.begin(); 157 Instruction &BrBB2 = Case2.back(); 158 Instruction &Add = BB1.front(); 159 Instruction &Or = BB2.front(); 160 Instruction &BrLoop = BB2.back(); 161 162 // mul 163 Cost Ref = getInstCost(Mul); 164 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 165 EXPECT_EQ(Bonus, Ref); 166 EXPECT_TRUE(Bonus > 0); 167 168 // and + or + add 169 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 170 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 171 EXPECT_EQ(Bonus, Ref); 172 EXPECT_TRUE(Bonus > 0); 173 174 // sdiv + br + br 175 Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrLoop); 176 Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 177 EXPECT_EQ(Bonus, Ref); 178 EXPECT_TRUE(Bonus > 0); 179 } 180 181 TEST_F(FunctionSpecializationTest, BranchInst) { 182 const char *ModuleString = R"( 183 define void @foo(i32 %a, i32 %b, i1 %cond) { 184 entry: 185 br label %loop 186 loop: 187 br i1 %cond, label %bb0, label %bb2 188 bb0: 189 %0 = mul i32 %a, 2 190 %1 = sub i32 6, 5 191 br label %bb1 192 bb1: 193 %2 = add i32 %0, %b 194 %3 = sdiv i32 8, 2 195 br label %loop 196 bb2: 197 ret void 198 } 199 )"; 200 201 Module &M = parseModule(ModuleString); 202 Function *F = M.getFunction("foo"); 203 FunctionSpecializer Specializer = getSpecializerFor(F); 204 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 205 206 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 207 Constant *False = ConstantInt::getFalse(M.getContext()); 208 209 auto FuncIter = F->begin(); 210 ++FuncIter; 211 BasicBlock &BB0 = *++FuncIter; 212 BasicBlock &BB1 = *++FuncIter; 213 214 Instruction &Mul = BB0.front(); 215 Instruction &Sub = *++BB0.begin(); 216 Instruction &BrBB1 = BB0.back(); 217 Instruction &Add = BB1.front(); 218 Instruction &Sdiv = *++BB1.begin(); 219 Instruction &BrLoop = BB1.back(); 220 221 // mul 222 Cost Ref = getInstCost(Mul); 223 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 224 EXPECT_EQ(Bonus, Ref); 225 EXPECT_TRUE(Bonus > 0); 226 227 // add 228 Ref = getInstCost(Add); 229 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 230 EXPECT_EQ(Bonus, Ref); 231 EXPECT_TRUE(Bonus > 0); 232 233 // sub + br + sdiv + br 234 Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) + 235 getInstCost(BrLoop); 236 Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 237 EXPECT_EQ(Bonus, Ref); 238 EXPECT_TRUE(Bonus > 0); 239 } 240 241 TEST_F(FunctionSpecializationTest, Misc) { 242 const char *ModuleString = R"( 243 %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 244 @g = constant %struct_t zeroinitializer, align 16 245 246 declare i32 @llvm.smax.i32(i32, i32) 247 declare i32 @bar(i32) 248 249 define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 250 %cmp = icmp eq i8 %a, 10 251 %ext = zext i1 %cmp to i64 252 %sel = select i1 %cond, i64 %ext, i64 1 253 %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 254 %ld = load i32, ptr %gep 255 %fr = freeze i32 %ld 256 %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 257 %call = call i32 @bar(i32 %smax) 258 %fr2 = freeze i32 %c 259 %add = add i32 %call, %fr2 260 ret i32 %add 261 } 262 )"; 263 264 Module &M = parseModule(ModuleString); 265 Function *F = M.getFunction("foo"); 266 FunctionSpecializer Specializer = getSpecializerFor(F); 267 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 268 269 GlobalVariable *GV = M.getGlobalVariable("g"); 270 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 271 Constant *True = ConstantInt::getTrue(M.getContext()); 272 Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 273 274 auto BlockIter = F->front().begin(); 275 Instruction &Icmp = *BlockIter++; 276 Instruction &Zext = *BlockIter++; 277 Instruction &Select = *BlockIter++; 278 Instruction &Gep = *BlockIter++; 279 Instruction &Load = *BlockIter++; 280 Instruction &Freeze = *BlockIter++; 281 Instruction &Smax = *BlockIter++; 282 283 // icmp + zext 284 Cost Ref = getInstCost(Icmp) + getInstCost(Zext); 285 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 286 EXPECT_EQ(Bonus, Ref); 287 EXPECT_TRUE(Bonus > 0); 288 289 // select 290 Ref = getInstCost(Select); 291 Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 292 EXPECT_EQ(Bonus, Ref); 293 EXPECT_TRUE(Bonus > 0); 294 295 // gep + load + freeze + smax 296 Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) + 297 getInstCost(Smax); 298 Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 299 EXPECT_EQ(Bonus, Ref); 300 EXPECT_TRUE(Bonus > 0); 301 302 Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor); 303 EXPECT_TRUE(Bonus == 0); 304 } 305