xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 0524534d5220da5ecb2cd424a46520184d2be366)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
23 
24 namespace llvm {
25 
26 class FunctionSpecializationTest : public testing::Test {
27 protected:
28   LLVMContext Ctx;
29   FunctionAnalysisManager FAM;
30   std::unique_ptr<Module> M;
31   std::unique_ptr<SCCPSolver> Solver;
32 
33   FunctionSpecializationTest() {
34     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
35     FAM.registerPass([&] { return TargetIRAnalysis(); });
36     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
37     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
38     FAM.registerPass([&] { return LoopAnalysis(); });
39     FAM.registerPass([&] { return AssumptionAnalysis(); });
40     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
41     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
42     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43   }
44 
45   Module &parseModule(const char *ModuleString) {
46     SMDiagnostic Err;
47     M = parseAssemblyString(ModuleString, Err, Ctx);
48     EXPECT_TRUE(M);
49     return *M;
50   }
51 
52   FunctionSpecializer getSpecializerFor(Function *F) {
53     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
54       return FAM.getResult<TargetLibraryAnalysis>(F);
55     };
56     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
57       return FAM.getResult<TargetIRAnalysis>(F);
58     };
59     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
60       return FAM.getResult<BlockFrequencyAnalysis>(F);
61     };
62     auto GetAC = [this](Function &F) -> AssumptionCache & {
63       return FAM.getResult<AssumptionAnalysis>(F);
64     };
65     auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
66       DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
67       return { std::make_unique<PredicateInfo>(F, DT,
68                                 FAM.getResult<AssumptionAnalysis>(F)),
69                &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
70     };
71 
72     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
73 
74     Solver->addAnalysis(*F, GetAnalysis(*F));
75     Solver->markBlockExecutable(&F->front());
76     for (Argument &Arg : F->args())
77       Solver->markOverdefined(&Arg);
78     Solver->solveWhileResolvedUndefsIn(*M);
79 
80     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
81                                GetAC);
82   }
83 
84   Cost getInstCost(Instruction &I) {
85     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
86     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
87 
88     uint64_t Weight = FunctionSpecializer::getBlockFreqMultiplier() *
89                       BFI.getBlockFreq(I.getParent()).getFrequency() /
90                       BFI.getEntryFreq();
91     return Weight *
92          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
93   }
94 };
95 
96 } // namespace llvm
97 
98 using namespace llvm;
99 
100 TEST_F(FunctionSpecializationTest, SwitchInst) {
101   const char *ModuleString = R"(
102     define void @foo(i32 %a, i32 %b, i32 %i) {
103     entry:
104       switch i32 %i, label %default
105       [ i32 1, label %case1
106         i32 2, label %case2 ]
107     case1:
108       %0 = mul i32 %a, 2
109       %1 = sub i32 6, 5
110       br label %bb1
111     case2:
112       %2 = and i32 %b, 3
113       %3 = sdiv i32 8, 2
114       br label %bb2
115     bb1:
116       %4 = add i32 %0, %b
117       br label %default
118     bb2:
119       %5 = or i32 %2, %a
120       br label %default
121     default:
122       ret void
123     }
124   )";
125 
126   Module &M = parseModule(ModuleString);
127   Function *F = M.getFunction("foo");
128   FunctionSpecializer Specializer = getSpecializerFor(F);
129   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
130 
131   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
132 
133   auto FuncIter = F->begin();
134   BasicBlock &Case1 = *++FuncIter;
135   BasicBlock &Case2 = *++FuncIter;
136   BasicBlock &BB1 = *++FuncIter;
137   BasicBlock &BB2 = *++FuncIter;
138 
139   Instruction &Mul = Case1.front();
140   Instruction &And = Case2.front();
141   Instruction &Sdiv = *++Case2.begin();
142   Instruction &BrBB2 = Case2.back();
143   Instruction &Add = BB1.front();
144   Instruction &Or = BB2.front();
145   Instruction &BrDefault = BB2.back();
146 
147   // mul
148   Cost Ref = getInstCost(Mul);
149   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
150   EXPECT_EQ(Bonus, Ref);
151 
152   // and + or + add
153   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
154   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
155   EXPECT_EQ(Bonus, Ref);
156 
157   // sdiv + br + br
158   Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
159   Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
160   EXPECT_EQ(Bonus, Ref);
161 }
162 
163 TEST_F(FunctionSpecializationTest, BranchInst) {
164   const char *ModuleString = R"(
165     define void @foo(i32 %a, i32 %b, i1 %cond) {
166     entry:
167       br i1 %cond, label %bb0, label %bb2
168     bb0:
169       %0 = mul i32 %a, 2
170       %1 = sub i32 6, 5
171       br label %bb1
172     bb1:
173       %2 = add i32 %0, %b
174       %3 = sdiv i32 8, 2
175       br label %bb2
176     bb2:
177       ret void
178     }
179   )";
180 
181   Module &M = parseModule(ModuleString);
182   Function *F = M.getFunction("foo");
183   FunctionSpecializer Specializer = getSpecializerFor(F);
184   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
185 
186   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
187   Constant *False = ConstantInt::getFalse(M.getContext());
188 
189   auto FuncIter = F->begin();
190   BasicBlock &BB0 = *++FuncIter;
191   BasicBlock &BB1 = *++FuncIter;
192 
193   Instruction &Mul = BB0.front();
194   Instruction &Sub = *++BB0.begin();
195   Instruction &BrBB1 = BB0.back();
196   Instruction &Add = BB1.front();
197   Instruction &Sdiv = *++BB1.begin();
198   Instruction &BrBB2 = BB1.back();
199 
200   // mul
201   Cost Ref = getInstCost(Mul);
202   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
203   EXPECT_EQ(Bonus, Ref);
204 
205   // add
206   Ref = getInstCost(Add);
207   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
208   EXPECT_EQ(Bonus, Ref);
209 
210   // sub + br + sdiv + br
211   Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
212         getInstCost(BrBB2);
213   Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
214   EXPECT_EQ(Bonus, Ref);
215 }
216 
217 TEST_F(FunctionSpecializationTest, Misc) {
218   const char *ModuleString = R"(
219     @g = constant [2 x i32] zeroinitializer, align 4
220 
221     define i32 @foo(i8 %a, i1 %cond, ptr %b) {
222       %cmp = icmp eq i8 %a, 10
223       %ext = zext i1 %cmp to i32
224       %sel = select i1 %cond, i32 %ext, i32 1
225       %gep = getelementptr i32, ptr %b, i32 %sel
226       %ld = load i32, ptr %gep
227       ret i32 %ld
228     }
229   )";
230 
231   Module &M = parseModule(ModuleString);
232   Function *F = M.getFunction("foo");
233   FunctionSpecializer Specializer = getSpecializerFor(F);
234   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
235 
236   GlobalVariable *GV = M.getGlobalVariable("g");
237   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
238   Constant *True = ConstantInt::getTrue(M.getContext());
239 
240   auto BlockIter = F->front().begin();
241   Instruction &Icmp = *BlockIter++;
242   Instruction &Zext = *BlockIter++;
243   Instruction &Select = *BlockIter++;
244   Instruction &Gep = *BlockIter++;
245   Instruction &Load = *BlockIter++;
246 
247   // icmp + zext
248   Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
249   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
250   EXPECT_EQ(Bonus, Ref);
251 
252   // select
253   Ref = getInstCost(Select);
254   Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
255   EXPECT_EQ(Bonus, Ref);
256 
257   // gep + load
258   Ref = getInstCost(Gep) + getInstCost(Load);
259   Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
260   EXPECT_EQ(Bonus, Ref);
261 }
262