xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision d1b376fd7bf73bca557f3c174d4c129ed4d45ae5)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
23 
24 namespace llvm {
25 
26 static void removeSSACopy(Function &F) {
27   for (BasicBlock &BB : F) {
28     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
29       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
30         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
31           continue;
32         Inst.replaceAllUsesWith(II->getOperand(0));
33         Inst.eraseFromParent();
34       }
35     }
36   }
37 }
38 
39 class FunctionSpecializationTest : public testing::Test {
40 protected:
41   LLVMContext Ctx;
42   FunctionAnalysisManager FAM;
43   std::unique_ptr<Module> M;
44   std::unique_ptr<SCCPSolver> Solver;
45 
46   FunctionSpecializationTest() {
47     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
48     FAM.registerPass([&] { return TargetIRAnalysis(); });
49     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
50     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
51     FAM.registerPass([&] { return LoopAnalysis(); });
52     FAM.registerPass([&] { return AssumptionAnalysis(); });
53     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
54     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
55     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
56   }
57 
58   Module &parseModule(const char *ModuleString) {
59     SMDiagnostic Err;
60     M = parseAssemblyString(ModuleString, Err, Ctx);
61     EXPECT_TRUE(M);
62     return *M;
63   }
64 
65   FunctionSpecializer getSpecializerFor(Function *F) {
66     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
67       return FAM.getResult<TargetLibraryAnalysis>(F);
68     };
69     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
70       return FAM.getResult<TargetIRAnalysis>(F);
71     };
72     auto GetAC = [this](Function &F) -> AssumptionCache & {
73       return FAM.getResult<AssumptionAnalysis>(F);
74     };
75     auto GetDT = [this](Function &F) -> DominatorTree & {
76       return FAM.getResult<DominatorTreeAnalysis>(F);
77     };
78     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
79       return FAM.getResult<BlockFrequencyAnalysis>(F);
80     };
81 
82     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
83 
84     DominatorTree &DT = GetDT(*F);
85     AssumptionCache &AC = GetAC(*F);
86     Solver->addPredicateInfo(*F, DT, AC);
87 
88     Solver->markBlockExecutable(&F->front());
89     for (Argument &Arg : F->args())
90       Solver->markOverdefined(&Arg);
91     Solver->solveWhileResolvedUndefsIn(*M);
92 
93     removeSSACopy(*F);
94 
95     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
96                                GetAC);
97   }
98 
99   Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
100     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
101     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
102 
103     Cost CodeSize =
104         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
105 
106     Cost Latency = SizeOnly ? 0 :
107         BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
108         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
109 
110     return {CodeSize, Latency};
111   }
112 };
113 
114 } // namespace llvm
115 
116 using namespace llvm;
117 
// Cost model for a loop whose body is a switch on %i. Each argument of @foo
// is specialized to the i32 constant 1 in turn. The bonus reported by the
// InstCostVisitor must equal the sum of getInstCost() over the instructions
// named in the comment above each check.
TEST_F(FunctionSpecializationTest, SwitchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i32 %i) {
    entry:
      br label %loop
    loop:
      switch i32 %i, label %default
      [ i32 1, label %case1
        i32 2, label %case2 ]
    case1:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br label %bb1
    case2:
      %2 = and i32 %b, 3
      %3 = sdiv i32 8, 2
      br label %bb2
    bb1:
      %4 = add i32 %0, %b
      br label %loop
    bb2:
      %5 = or i32 %2, %a
      br label %loop
    default:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  // NOTE(review): one visitor answers all three queries below. Keep their
  // order unless you have confirmed it carries no state between calls.
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Blocks in layout order, skipping the entry block.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &Case1 = *++FuncIter;
  BasicBlock &Case2 = *++FuncIter;
  BasicBlock &BB1 = *++FuncIter;
  BasicBlock &BB2 = *++FuncIter;

  Instruction &Switch = Loop.front();
  Instruction &Mul = Case1.front();
  Instruction &And = Case2.front();
  Instruction &Sdiv = *++Case2.begin(); // second instruction of case2
  Instruction &BrBB2 = Case2.back();
  Instruction &Add = BB1.front();
  Instruction &Or = BB2.front();
  Instruction &BrLoop = BB2.back();

  // %a = 1: expected bonus is the cost of the mul in case1.
  // mul
  Bonus Ref = getInstCost(Mul);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %b = 1: expected bonus is the and, or and add that use %b directly or
  // transitively.
  // and + or + add
  Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %i = 1 routes the switch to case1. The switch is charged in full; the
  // remaining instructions in the list are charged code size only.
  // switch + sdiv + br + br
  Ref = getInstCost(Switch) +
        getInstCost(Sdiv, /*SizeOnly =*/ true) +
        getInstCost(BrBB2, /*SizeOnly =*/ true) +
        getInstCost(BrLoop, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
190 
// Cost model for a loop guarded by conditional branches on %cond. %a and %b
// are specialized to the i32 constant 1 and %cond to false. Each reported
// bonus must equal the sum of getInstCost() over the instructions named in
// the comment above each check.
TEST_F(FunctionSpecializationTest, BranchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i1 %cond) {
    entry:
      br label %loop
    loop:
      br i1 %cond, label %bb0, label %bb3
    bb0:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br i1 %cond, label %bb1, label %bb2
    bb1:
      %2 = add i32 %0, %b
      %3 = sdiv i32 8, 2
      br label %bb2
    bb2:
      br label %loop
    bb3:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  // NOTE(review): one visitor answers all three queries below. Keep their
  // order unless you have confirmed it carries no state between calls.
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
  Constant *False = ConstantInt::getFalse(M.getContext());

  // Blocks in layout order, skipping the entry block.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &BB0 = *++FuncIter;
  BasicBlock &BB1 = *++FuncIter;
  BasicBlock &BB2 = *++FuncIter;

  Instruction &Branch = Loop.front();
  Instruction &Mul = BB0.front();
  Instruction &Sub = *++BB0.begin(); // second instruction of bb0
  Instruction &BrBB1BB2 = BB0.back();
  Instruction &Add = BB1.front();
  Instruction &Sdiv = *++BB1.begin(); // second instruction of bb1
  Instruction &BrBB2 = BB1.back();
  Instruction &BrLoop = BB2.front();

  // %a = 1: expected bonus is the cost of the mul in bb0.
  // mul
  Bonus Ref = getInstCost(Mul);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %b = 1: expected bonus is the cost of the add in bb1.
  // add
  Ref = getInstCost(Add);
  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %cond = false sends the loop branch to %bb3. The two conditional
  // branches are charged in full; the other instructions in the list are
  // charged code size only.
  // branch + sub + br + sdiv + br
  Ref = getInstCost(Branch) +
        getInstCost(Sub, /*SizeOnly =*/ true) +
        getInstCost(BrBB1BB2) +
        getInstCost(Sdiv, /*SizeOnly =*/ true) +
        getInstCost(BrBB2, /*SizeOnly =*/ true) +
        getInstCost(BrLoop, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), False);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
259 
// Cost model for a straight-line function mixing icmp, zext, select, gep,
// load, freeze and calls (an intrinsic and an external function). Each
// argument is specialized to a constant of matching type, and each bonus is
// compared with the sum of getInstCost() over the instructions named above
// its check.
TEST_F(FunctionSpecializationTest, Misc) {
  const char *ModuleString = R"(
    %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
    @g = constant %struct_t zeroinitializer, align 16

    declare i32 @llvm.smax.i32(i32, i32)
    declare i32 @bar(i32)

    define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
      %cmp = icmp eq i8 %a, 10
      %ext = zext i1 %cmp to i64
      %sel = select i1 %cond, i64 %ext, i64 1
      %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
      %ld = load i32, ptr %gep
      %fr = freeze i32 %ld
      %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
      %call = call i32 @bar(i32 %smax)
      %fr2 = freeze i32 %c
      %add = add i32 %call, %fr2
      ret i32 %add
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  // NOTE(review): one visitor answers all four queries below. Keep their
  // order unless you have confirmed it carries no state between calls.
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  // Constants whose types match the arguments they are bound to: i8 for %a,
  // i1 for %cond, the global @g for %b, and i32 undef for %c.
  GlobalVariable *GV = M.getGlobalVariable("g");
  Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
  Constant *True = ConstantInt::getTrue(M.getContext());
  Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));

  // The first seven instructions of the single block, in order.
  auto BlockIter = F->front().begin();
  Instruction &Icmp = *BlockIter++;
  Instruction &Zext = *BlockIter++;
  Instruction &Select = *BlockIter++;
  Instruction &Gep = *BlockIter++;
  Instruction &Load = *BlockIter++;
  Instruction &Freeze = *BlockIter++;
  Instruction &Smax = *BlockIter++;

  // %a = 1
  // icmp + zext
  Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %cond = true
  // select
  Ref = getInstCost(Select);
  Test = Visitor.getSpecializationBonus(F->getArg(1), True);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %b = @g (a zero-initialized constant global)
  // gep + load + freeze + smax
  Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
        getInstCost(Smax);
  Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // %c = undef: expected to yield no bonus at all.
  Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
}
324 
// Cost model in the presence of PHI nodes. Specializing %a or %b alone is
// expected to give no immediate bonus, since they are only used by PHIs.
// Specializing %i folds the switch. getBonusFromPendingPHIs() then reports
// the bonus for the PHIs (and their users) left pending by the earlier
// queries, so the order of the queries below is significant.
TEST_F(FunctionSpecializationTest, PhiNode) {
  // NOTE(review): the phi %2 in case2 lists %entry as an incoming block, but
  // entry only branches to %loop. This IR would presumably be rejected by
  // the verifier; the test appears to rely on it never being verified.
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i32 %i) {
    entry:
      br label %loop
    loop:
      %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
      switch i32 %i, label %default
      [ i32 1, label %case1
        i32 2, label %case2 ]
    case1:
      %1 = add i32 %0, 1
      br label %bb
    case2:
      %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
      br label %bb
    bb:
      %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
      %4 = icmp eq i32 %3, 1
      br i1 %4, label %bb, label %loop
    default:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Blocks in layout order, skipping the entry block.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &Case1 = *++FuncIter;
  BasicBlock &Case2 = *++FuncIter;
  BasicBlock &BB = *++FuncIter;

  Instruction &PhiLoop = Loop.front();
  Instruction &Switch = Loop.back();
  Instruction &Add = Case1.front();
  Instruction &PhiCase2 = Case2.front();
  Instruction &BrBB = Case2.back();
  Instruction &PhiBB = BB.front();
  Instruction &Icmp = *++BB.begin(); // second instruction of bb
  Instruction &Branch = BB.back();

  // %a = 1: no immediate bonus expected.
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);

  // %b = 1: no immediate bonus expected.
  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);

  // %i = 1 routes the switch to case1. The switch is charged in full; the
  // phi and br of case2 are charged code size only.
  // switch + phi + br
  Bonus Ref = getInstCost(Switch) +
              getInstCost(PhiCase2, /*SizeOnly =*/ true) +
              getInstCost(BrBB, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // Bonus deferred from the queries above, now resolved through the PHIs.
  // phi + phi + add + icmp + branch
  Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
        getInstCost(Icmp) + getInstCost(Branch);
  Test = Visitor.getBonusFromPendingPHIs();
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
393 
394