xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 20c8f58c11d06d4a31fef5033ea5417fb99dfd0e)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <iterator>
#include <memory>
23 
24 namespace llvm {
25 
26 class FunctionSpecializationTest : public testing::Test {
27 protected:
28   LLVMContext Ctx;
29   FunctionAnalysisManager FAM;
30   std::unique_ptr<Module> M;
31   std::unique_ptr<SCCPSolver> Solver;
32 
33   FunctionSpecializationTest() {
34     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
35     FAM.registerPass([&] { return TargetIRAnalysis(); });
36     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
37     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
38     FAM.registerPass([&] { return LoopAnalysis(); });
39     FAM.registerPass([&] { return AssumptionAnalysis(); });
40     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
41     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
42     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43   }
44 
45   Module &parseModule(const char *ModuleString) {
46     SMDiagnostic Err;
47     M = parseAssemblyString(ModuleString, Err, Ctx);
48     EXPECT_TRUE(M);
49     return *M;
50   }
51 
52   FunctionSpecializer getSpecializerFor(Function *F) {
53     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
54       return FAM.getResult<TargetLibraryAnalysis>(F);
55     };
56     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
57       return FAM.getResult<TargetIRAnalysis>(F);
58     };
59     auto GetAC = [this](Function &F) -> AssumptionCache & {
60       return FAM.getResult<AssumptionAnalysis>(F);
61     };
62     auto GetDT = [this](Function &F) -> DominatorTree & {
63       return FAM.getResult<DominatorTreeAnalysis>(F);
64     };
65     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
66       return FAM.getResult<BlockFrequencyAnalysis>(F);
67     };
68 
69     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
70 
71     DominatorTree &DT = GetDT(*F);
72     AssumptionCache &AC = GetAC(*F);
73     Solver->addPredicateInfo(*F, DT, AC);
74 
75     Solver->markBlockExecutable(&F->front());
76     for (Argument &Arg : F->args())
77       Solver->markOverdefined(&Arg);
78     Solver->solveWhileResolvedUndefsIn(*M);
79 
80     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
81                                GetAC);
82   }
83 
84   Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
85     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
86     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
87 
88     Cost CodeSize =
89         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
90 
91     Cost Latency = SizeOnly ? 0 :
92         BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
93         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
94 
95     return {CodeSize, Latency};
96   }
97 };
98 
99 } // namespace llvm
100 
101 using namespace llvm;
102 
103 TEST_F(FunctionSpecializationTest, SwitchInst) {
104   const char *ModuleString = R"(
105     define void @foo(i32 %a, i32 %b, i32 %i) {
106     entry:
107       br label %loop
108     loop:
109       switch i32 %i, label %default
110       [ i32 1, label %case1
111         i32 2, label %case2 ]
112     case1:
113       %0 = mul i32 %a, 2
114       %1 = sub i32 6, 5
115       br label %bb1
116     case2:
117       %2 = and i32 %b, 3
118       %3 = sdiv i32 8, 2
119       br label %bb2
120     bb1:
121       %4 = add i32 %0, %b
122       br label %loop
123     bb2:
124       %5 = or i32 %2, %a
125       br label %loop
126     default:
127       ret void
128     }
129   )";
130 
131   Module &M = parseModule(ModuleString);
132   Function *F = M.getFunction("foo");
133   FunctionSpecializer Specializer = getSpecializerFor(F);
134   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
135 
136   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
137 
138   auto FuncIter = F->begin();
139   BasicBlock &Loop = *++FuncIter;
140   BasicBlock &Case1 = *++FuncIter;
141   BasicBlock &Case2 = *++FuncIter;
142   BasicBlock &BB1 = *++FuncIter;
143   BasicBlock &BB2 = *++FuncIter;
144 
145   Instruction &Switch = Loop.front();
146   Instruction &Mul = Case1.front();
147   Instruction &And = Case2.front();
148   Instruction &Sdiv = *++Case2.begin();
149   Instruction &BrBB2 = Case2.back();
150   Instruction &Add = BB1.front();
151   Instruction &Or = BB2.front();
152   Instruction &BrLoop = BB2.back();
153 
154   // mul
155   Bonus Ref = getInstCost(Mul);
156   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
157   EXPECT_EQ(Test, Ref);
158   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
159 
160   // and + or + add
161   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
162   Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
163   EXPECT_EQ(Test, Ref);
164   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
165 
166   // switch + sdiv + br + br
167   Ref = getInstCost(Switch) +
168         getInstCost(Sdiv, /*SizeOnly =*/ true) +
169         getInstCost(BrBB2, /*SizeOnly =*/ true) +
170         getInstCost(BrLoop, /*SizeOnly =*/ true);
171   Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
172   EXPECT_EQ(Test, Ref);
173   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
174 }
175 
176 TEST_F(FunctionSpecializationTest, BranchInst) {
177   const char *ModuleString = R"(
178     define void @foo(i32 %a, i32 %b, i1 %cond) {
179     entry:
180       br label %loop
181     loop:
182       br i1 %cond, label %bb0, label %bb2
183     bb0:
184       %0 = mul i32 %a, 2
185       %1 = sub i32 6, 5
186       br label %bb1
187     bb1:
188       %2 = add i32 %0, %b
189       %3 = sdiv i32 8, 2
190       br label %loop
191     bb2:
192       ret void
193     }
194   )";
195 
196   Module &M = parseModule(ModuleString);
197   Function *F = M.getFunction("foo");
198   FunctionSpecializer Specializer = getSpecializerFor(F);
199   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
200 
201   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
202   Constant *False = ConstantInt::getFalse(M.getContext());
203 
204   auto FuncIter = F->begin();
205   BasicBlock &Loop = *++FuncIter;
206   BasicBlock &BB0 = *++FuncIter;
207   BasicBlock &BB1 = *++FuncIter;
208 
209   Instruction &Branch = Loop.front();
210   Instruction &Mul = BB0.front();
211   Instruction &Sub = *++BB0.begin();
212   Instruction &BrBB1 = BB0.back();
213   Instruction &Add = BB1.front();
214   Instruction &Sdiv = *++BB1.begin();
215   Instruction &BrLoop = BB1.back();
216 
217   // mul
218   Bonus Ref = getInstCost(Mul);
219   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
220   EXPECT_EQ(Test, Ref);
221   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
222 
223   // add
224   Ref = getInstCost(Add);
225   Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
226   EXPECT_EQ(Test, Ref);
227   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
228 
229   // branch + sub + br + sdiv + br
230   Ref = getInstCost(Branch) +
231         getInstCost(Sub, /*SizeOnly =*/ true) +
232         getInstCost(BrBB1, /*SizeOnly =*/ true) +
233         getInstCost(Sdiv, /*SizeOnly =*/ true) +
234         getInstCost(BrLoop, /*SizeOnly =*/ true);
235   Test = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
236   EXPECT_EQ(Test, Ref);
237   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
238 }
239 
240 TEST_F(FunctionSpecializationTest, Misc) {
241   const char *ModuleString = R"(
242     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
243     @g = constant %struct_t zeroinitializer, align 16
244 
245     declare i32 @llvm.smax.i32(i32, i32)
246     declare i32 @bar(i32)
247 
248     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
249       %cmp = icmp eq i8 %a, 10
250       %ext = zext i1 %cmp to i64
251       %sel = select i1 %cond, i64 %ext, i64 1
252       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
253       %ld = load i32, ptr %gep
254       %fr = freeze i32 %ld
255       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
256       %call = call i32 @bar(i32 %smax)
257       %fr2 = freeze i32 %c
258       %add = add i32 %call, %fr2
259       ret i32 %add
260     }
261   )";
262 
263   Module &M = parseModule(ModuleString);
264   Function *F = M.getFunction("foo");
265   FunctionSpecializer Specializer = getSpecializerFor(F);
266   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
267 
268   GlobalVariable *GV = M.getGlobalVariable("g");
269   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
270   Constant *True = ConstantInt::getTrue(M.getContext());
271   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
272 
273   auto BlockIter = F->front().begin();
274   Instruction &Icmp = *BlockIter++;
275   Instruction &Zext = *BlockIter++;
276   Instruction &Select = *BlockIter++;
277   Instruction &Gep = *BlockIter++;
278   Instruction &Load = *BlockIter++;
279   Instruction &Freeze = *BlockIter++;
280   Instruction &Smax = *BlockIter++;
281 
282   // icmp + zext
283   Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
284   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
285   EXPECT_EQ(Test, Ref);
286   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
287 
288   // select
289   Ref = getInstCost(Select);
290   Test = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
291   EXPECT_EQ(Test, Ref);
292   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
293 
294   // gep + load + freeze + smax
295   Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
296         getInstCost(Smax);
297   Test = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
298   EXPECT_EQ(Test, Ref);
299   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
300 
301   Test = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
302   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
303 }
304 
305 TEST_F(FunctionSpecializationTest, PhiNode) {
306   const char *ModuleString = R"(
307     define void @foo(i32 %a, i32 %b, i32 %i) {
308     entry:
309       br label %loop
310     loop:
311       switch i32 %i, label %default
312       [ i32 1, label %case1
313         i32 2, label %case2 ]
314     case1:
315       %0 = add i32 %a, 1
316       br label %bb
317     case2:
318       %1 = sub i32 %b, 1
319       br label %bb
320     bb:
321       %2 = phi i32 [ %0, %case1 ], [ %1, %case2 ], [ %2, %bb ]
322       %3 = icmp eq i32 %2, 2
323       br i1 %3, label %bb, label %loop
324     default:
325       ret void
326     }
327   )";
328 
329   Module &M = parseModule(ModuleString);
330   Function *F = M.getFunction("foo");
331   FunctionSpecializer Specializer = getSpecializerFor(F);
332   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
333 
334   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
335 
336   auto FuncIter = F->begin();
337   for (int I = 0; I < 4; ++I)
338     ++FuncIter;
339 
340   BasicBlock &BB = *FuncIter;
341 
342   Instruction &Phi = BB.front();
343   Instruction &Icmp = *++BB.begin();
344   Instruction &Branch = BB.back();
345 
346   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor) +
347                Specializer.getSpecializationBonus(F->getArg(1), One, Visitor) +
348                Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
349   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
350 
351   // phi + icmp + branch
352   Bonus Ref = getInstCost(Phi) + getInstCost(Icmp) + getInstCost(Branch);
353   Test = Visitor.getBonusFromPendingPHIs();
354   EXPECT_EQ(Test, Ref);
355   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
356 }
357 
358