1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/AssumptionCache.h" 10 #include "llvm/Analysis/BlockFrequencyInfo.h" 11 #include "llvm/Analysis/BranchProbabilityInfo.h" 12 #include "llvm/Analysis/LoopInfo.h" 13 #include "llvm/Analysis/PostDominators.h" 14 #include "llvm/Analysis/TargetLibraryInfo.h" 15 #include "llvm/Analysis/TargetTransformInfo.h" 16 #include "llvm/AsmParser/Parser.h" 17 #include "llvm/IR/Constants.h" 18 #include "llvm/Support/SourceMgr.h" 19 #include "llvm/Transforms/IPO/FunctionSpecialization.h" 20 #include "llvm/Transforms/Utils/SCCPSolver.h" 21 #include "gtest/gtest.h" 22 #include <memory> 23 24 namespace llvm { 25 26 class FunctionSpecializationTest : public testing::Test { 27 protected: 28 LLVMContext Ctx; 29 FunctionAnalysisManager FAM; 30 std::unique_ptr<Module> M; 31 std::unique_ptr<SCCPSolver> Solver; 32 33 FunctionSpecializationTest() { 34 FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 35 FAM.registerPass([&] { return TargetIRAnalysis(); }); 36 FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 37 FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 38 FAM.registerPass([&] { return LoopAnalysis(); }); 39 FAM.registerPass([&] { return AssumptionAnalysis(); }); 40 FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 41 FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 42 FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 43 } 44 45 Module &parseModule(const char *ModuleString) { 46 SMDiagnostic Err; 47 M = parseAssemblyString(ModuleString, Err, Ctx); 48 EXPECT_TRUE(M); 49 return *M; 50 } 51 52 FunctionSpecializer getSpecializerFor(Function *F) { 53 auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 54 return FAM.getResult<TargetLibraryAnalysis>(F); 55 }; 56 auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 57 return FAM.getResult<TargetIRAnalysis>(F); 58 }; 59 auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 60 return FAM.getResult<BlockFrequencyAnalysis>(F); 61 }; 62 auto GetAC = [this](Function &F) -> AssumptionCache & { 63 return FAM.getResult<AssumptionAnalysis>(F); 64 }; 65 auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn { 66 DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); 67 return { std::make_unique<PredicateInfo>(F, DT, 68 FAM.getResult<AssumptionAnalysis>(F)), 69 &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) }; 70 }; 71 72 Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 73 74 Solver->addAnalysis(*F, GetAnalysis(*F)); 75 Solver->markBlockExecutable(&F->front()); 76 for (Argument &Arg : F->args()) 77 Solver->markOverdefined(&Arg); 78 Solver->solveWhileResolvedUndefsIn(*M); 79 80 return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 81 GetAC); 82 } 83 84 Cost getInstCost(Instruction &I) { 85 auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 86 auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction()); 87 88 return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() * 89 TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency); 90 } 91 }; 92 93 } // namespace llvm 94 95 using namespace llvm; 96 97 TEST_F(FunctionSpecializationTest, SwitchInst) { 98 const char *ModuleString = R"( 99 define void @foo(i32 %a, i32 %b, i32 %i) { 100 entry: 101 switch i32 %i, label %default 102 [ i32 1, label %case1 103 i32 2, label %case2 ] 104 case1: 105 %0 = mul i32 %a, 2 106 %1 = sub i32 6, 5 107 br label %bb1 108 case2: 109 %2 = and i32 %b, 3 110 %3 = sdiv i32 8, 2 111 br label %bb2 112 bb1: 113 %4 = add i32 %0, %b 114 br label %default 115 bb2: 116 %5 = or i32 %2, %a 117 br label %default 118 default: 119 ret void 120 } 121 )"; 122 123 Module &M = parseModule(ModuleString); 124 Function *F = M.getFunction("foo"); 125 FunctionSpecializer Specializer = getSpecializerFor(F); 126 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 127 128 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 129 130 auto FuncIter = F->begin(); 131 BasicBlock &Case1 = *++FuncIter; 132 BasicBlock &Case2 = *++FuncIter; 133 BasicBlock &BB1 = *++FuncIter; 134 BasicBlock &BB2 = *++FuncIter; 135 136 Instruction &Mul = Case1.front(); 137 Instruction &And = Case2.front(); 138 Instruction &Sdiv = *++Case2.begin(); 139 Instruction &BrBB2 = Case2.back(); 140 Instruction &Add = BB1.front(); 141 Instruction &Or = BB2.front(); 142 Instruction &BrDefault = BB2.back(); 143 144 // mul 145 Cost Ref = getInstCost(Mul); 146 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 147 EXPECT_EQ(Bonus, Ref); 148 149 // and + or + add 150 Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add); 151 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 152 EXPECT_EQ(Bonus, Ref); 153 154 // sdiv + br + br 155 Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault); 156 Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor); 157 EXPECT_EQ(Bonus, Ref); 158 } 159 160 TEST_F(FunctionSpecializationTest, BranchInst) { 161 const char *ModuleString = R"( 162 define void @foo(i32 %a, i32 %b, i1 %cond) { 163 entry: 164 br i1 %cond, label %bb0, label %bb2 165 bb0: 166 %0 = mul i32 %a, 2 167 %1 = sub i32 6, 5 168 br label %bb1 169 bb1: 170 %2 = add i32 %0, %b 171 %3 = sdiv i32 8, 2 172 br label %bb2 173 bb2: 174 ret void 175 } 176 )"; 177 178 Module &M = parseModule(ModuleString); 179 Function *F = M.getFunction("foo"); 180 FunctionSpecializer Specializer = getSpecializerFor(F); 181 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 182 183 Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 184 Constant *False = ConstantInt::getFalse(M.getContext()); 185 186 auto FuncIter = F->begin(); 187 BasicBlock &BB0 = *++FuncIter; 188 BasicBlock &BB1 = *++FuncIter; 189 190 Instruction &Mul = BB0.front(); 191 Instruction &Sub = *++BB0.begin(); 192 Instruction &BrBB1 = BB0.back(); 193 Instruction &Add = BB1.front(); 194 Instruction &Sdiv = *++BB1.begin(); 195 Instruction &BrBB2 = BB1.back(); 196 197 // mul 198 Cost Ref = getInstCost(Mul); 199 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 200 EXPECT_EQ(Bonus, Ref); 201 202 // add 203 Ref = getInstCost(Add); 204 Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor); 205 EXPECT_EQ(Bonus, Ref); 206 207 // sub + br + sdiv + br 208 Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) + 209 getInstCost(BrBB2); 210 Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor); 211 EXPECT_EQ(Bonus, Ref); 212 } 213 214 TEST_F(FunctionSpecializationTest, Misc) { 215 const char *ModuleString = R"( 216 @g = constant [2 x i32] zeroinitializer, align 4 217 218 define i32 @foo(i8 %a, i1 %cond, ptr %b) { 219 %cmp = icmp eq i8 %a, 10 220 %ext = zext i1 %cmp to i32 221 %sel = select i1 %cond, i32 %ext, i32 1 222 %gep = getelementptr i32, ptr %b, i32 %sel 223 %ld = load i32, ptr %gep 224 ret i32 %ld 225 } 226 )"; 227 228 Module &M = parseModule(ModuleString); 229 Function *F = M.getFunction("foo"); 230 FunctionSpecializer Specializer = getSpecializerFor(F); 231 InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 232 233 GlobalVariable *GV = M.getGlobalVariable("g"); 234 Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 235 Constant *True = ConstantInt::getTrue(M.getContext()); 236 237 auto BlockIter = F->front().begin(); 238 Instruction &Icmp = *BlockIter++; 239 Instruction &Zext = *BlockIter++; 240 Instruction &Select = *BlockIter++; 241 Instruction &Gep = *BlockIter++; 242 Instruction &Load = *BlockIter++; 243 244 // icmp + zext 245 Cost Ref = getInstCost(Icmp) + getInstCost(Zext); 246 Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor); 247 EXPECT_EQ(Bonus, Ref); 248 249 // select 250 Ref = getInstCost(Select); 251 Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor); 252 EXPECT_EQ(Bonus, Ref); 253 254 // gep + load 255 Ref = getInstCost(Gep) + getInstCost(Load); 256 Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor); 257 EXPECT_EQ(Bonus, Ref); 258 } 259