xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 36c6632eb43bf67e19c8a6a21981cf66e06389b4)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/PassInstrumentation.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/IPO/FunctionSpecialization.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "gtest/gtest.h"
#include <memory>
24 
25 namespace llvm {
26 
27 static void removeSSACopy(Function &F) {
28   for (BasicBlock &BB : F) {
29     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
30       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
31         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
32           continue;
33         Inst.replaceAllUsesWith(II->getOperand(0));
34         Inst.eraseFromParent();
35       }
36     }
37   }
38 }
39 
40 class FunctionSpecializationTest : public testing::Test {
41 protected:
42   LLVMContext Ctx;
43   FunctionAnalysisManager FAM;
44   std::unique_ptr<Module> M;
45   std::unique_ptr<SCCPSolver> Solver;
46 
47   FunctionSpecializationTest() {
48     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
49     FAM.registerPass([&] { return TargetIRAnalysis(); });
50     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
51     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
52     FAM.registerPass([&] { return LoopAnalysis(); });
53     FAM.registerPass([&] { return AssumptionAnalysis(); });
54     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
55     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
56     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
57   }
58 
59   Module &parseModule(const char *ModuleString) {
60     SMDiagnostic Err;
61     M = parseAssemblyString(ModuleString, Err, Ctx);
62     EXPECT_TRUE(M);
63     return *M;
64   }
65 
66   FunctionSpecializer getSpecializerFor(Function *F) {
67     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
68       return FAM.getResult<TargetLibraryAnalysis>(F);
69     };
70     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
71       return FAM.getResult<TargetIRAnalysis>(F);
72     };
73     auto GetAC = [this](Function &F) -> AssumptionCache & {
74       return FAM.getResult<AssumptionAnalysis>(F);
75     };
76     auto GetDT = [this](Function &F) -> DominatorTree & {
77       return FAM.getResult<DominatorTreeAnalysis>(F);
78     };
79     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
80       return FAM.getResult<BlockFrequencyAnalysis>(F);
81     };
82 
83     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
84 
85     DominatorTree &DT = GetDT(*F);
86     AssumptionCache &AC = GetAC(*F);
87     Solver->addPredicateInfo(*F, DT, AC);
88 
89     Solver->markBlockExecutable(&F->front());
90     for (Argument &Arg : F->args())
91       Solver->markOverdefined(&Arg);
92     Solver->solveWhileResolvedUndefsIn(*M);
93 
94     removeSSACopy(*F);
95 
96     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
97                                GetAC);
98   }
99 
100   Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
101     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
102     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
103 
104     Cost CodeSize =
105         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
106 
107     Cost Latency =
108         SizeOnly
109             ? 0
110             : BFI.getBlockFreq(I.getParent()).getFrequency() /
111                   BFI.getEntryFreq().getFrequency() *
112                   TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
113 
114     return {CodeSize, Latency};
115   }
116 };
117 
118 } // namespace llvm
119 
120 using namespace llvm;
121 
// Checks the specialization bonus reported for each argument of a function
// whose loop body is selected by `switch i32 %i`.
TEST_F(FunctionSpecializationTest, SwitchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i32 %i) {
    entry:
      br label %loop
    loop:
      switch i32 %i, label %default
      [ i32 1, label %case1
        i32 2, label %case2 ]
    case1:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br label %bb1
    case2:
      %2 = and i32 %b, 3
      %3 = sdiv i32 8, 2
      br label %bb2
    bb1:
      %4 = add i32 %0, %b
      br label %loop
    bb2:
      %5 = or i32 %2, %a
      br label %loop
    default:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Fetch the blocks in textual order; the pre-increment skips `entry`.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &Case1 = *++FuncIter;
  BasicBlock &Case2 = *++FuncIter;
  BasicBlock &BB1 = *++FuncIter;
  BasicBlock &BB2 = *++FuncIter;

  Instruction &Switch = Loop.front();
  Instruction &Mul = Case1.front();
  Instruction &And = Case2.front();
  Instruction &Sdiv = *++Case2.begin();
  Instruction &BrBB2 = Case2.back();
  Instruction &Add = BB1.front();
  Instruction &Or = BB2.front();
  Instruction &BrLoop = BB2.back();

  // The calls below are order-dependent. The expected bonus for %b includes
  // `add %0, %b` and `or %2, %a`, which can only fold if %a is still known
  // from the first call. Keep the calls in this order.

  // mul
  Bonus Ref = getInstCost(Mul);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // and + or + add
  Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // switch + sdiv + br + br
  // With %i == 1 only case1 is taken. The expected bonus counts instructions
  // on the now-dead case2 -> bb2 path by code size only.
  Ref = getInstCost(Switch) +
        getInstCost(Sdiv, /*SizeOnly =*/ true) +
        getInstCost(BrBB2, /*SizeOnly =*/ true) +
        getInstCost(BrLoop, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
194 
// Checks the specialization bonus reported for each argument of a function
// whose loop is controlled by conditional branches on `i1 %cond`.
TEST_F(FunctionSpecializationTest, BranchInst) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i1 %cond) {
    entry:
      br label %loop
    loop:
      br i1 %cond, label %bb0, label %bb3
    bb0:
      %0 = mul i32 %a, 2
      %1 = sub i32 6, 5
      br i1 %cond, label %bb1, label %bb2
    bb1:
      %2 = add i32 %0, %b
      %3 = sdiv i32 8, 2
      br label %bb2
    bb2:
      br label %loop
    bb3:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
  Constant *False = ConstantInt::getFalse(M.getContext());

  // Fetch the blocks in textual order; the pre-increment skips `entry`.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &BB0 = *++FuncIter;
  BasicBlock &BB1 = *++FuncIter;
  BasicBlock &BB2 = *++FuncIter;

  Instruction &Branch = Loop.front();
  Instruction &Mul = BB0.front();
  Instruction &Sub = *++BB0.begin();
  Instruction &BrBB1BB2 = BB0.back();
  Instruction &Add = BB1.front();
  Instruction &Sdiv = *++BB1.begin();
  Instruction &BrBB2 = BB1.back();
  Instruction &BrLoop = BB2.front();

  // The calls below are order-dependent. The expected bonus for %b includes
  // `add %0, %b`, which can only fold if %a is still known from the first
  // call. Keep the calls in this order.

  // mul
  Bonus Ref = getInstCost(Mul);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // add
  Ref = getInstCost(Add);
  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // branch + sub + br + sdiv + br
  // With %cond == false the loop exits to bb3, so bb0..bb2 become dead and
  // their instructions are expected at code size only. The exception is
  // BrBB1BB2: it branches on %cond directly and is expected at full cost.
  Ref = getInstCost(Branch) +
        getInstCost(Sub, /*SizeOnly =*/ true) +
        getInstCost(BrBB1BB2) +
        getInstCost(Sdiv, /*SizeOnly =*/ true) +
        getInstCost(BrBB2, /*SizeOnly =*/ true) +
        getInstCost(BrLoop, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), False);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
263 
// Checks the bonus for a straight-line function exercising icmp, zext, select,
// getelementptr, load, freeze, an intrinsic call and an opaque call.
TEST_F(FunctionSpecializationTest, Misc) {
  const char *ModuleString = R"(
    %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
    @g = constant %struct_t zeroinitializer, align 16

    declare i32 @llvm.smax.i32(i32, i32)
    declare i32 @bar(i32)

    define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
      %cmp = icmp eq i8 %a, 10
      %ext = zext i1 %cmp to i64
      %sel = select i1 %cond, i64 %ext, i64 1
      %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
      %ld = load i32, ptr %gep
      %fr = freeze i32 %ld
      %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
      %call = call i32 @bar(i32 %smax)
      %fr2 = freeze i32 %c
      %add = add i32 %call, %fr2
      ret i32 %add
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  GlobalVariable *GV = M.getGlobalVariable("g");
  Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
  Constant *True = ConstantInt::getTrue(M.getContext());
  Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));

  // The function has a single block; take its instructions in order.
  auto BlockIter = F->front().begin();
  Instruction &Icmp = *BlockIter++;
  Instruction &Zext = *BlockIter++;
  Instruction &Select = *BlockIter++;
  Instruction &Gep = *BlockIter++;
  Instruction &Load = *BlockIter++;
  Instruction &Freeze = *BlockIter++;
  Instruction &Smax = *BlockIter++;

  // The calls below are order-dependent. The gep's index %sel is only known
  // once %a and %cond have been supplied by the earlier calls.

  // icmp + zext
  Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // select
  Ref = getInstCost(Select);
  Test = Visitor.getSpecializationBonus(F->getArg(1), True);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // gep + load + freeze + smax
  // With %b bound to the constant global @g, the chain folds up to the smax.
  // The call to the opaque @bar is not part of the expected bonus.
  Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
        getInstCost(Smax);
  Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // An undef value for %c is expected to yield no bonus at all.
  Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
}
328 
// Checks how PHI nodes contribute to the bonus, including the work that is
// only accounted for once pending PHIs are resolved.
TEST_F(FunctionSpecializationTest, PhiNode) {
  const char *ModuleString = R"(
    define void @foo(i32 %a, i32 %b, i32 %i) {
    entry:
      br label %loop
    loop:
      %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
      switch i32 %i, label %default
      [ i32 1, label %case1
        i32 2, label %case2 ]
    case1:
      %1 = add i32 %0, 1
      br label %bb
    case2:
      %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
      br label %bb
    bb:
      %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
      %4 = icmp eq i32 %3, 1
      br i1 %4, label %bb, label %loop
    default:
      ret void
    }
  )";

  Module &M = parseModule(ModuleString);
  Function *F = M.getFunction("foo");
  FunctionSpecializer Specializer = getSpecializerFor(F);
  InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);

  Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);

  // Fetch the blocks in textual order; the pre-increment skips `entry`.
  auto FuncIter = F->begin();
  BasicBlock &Loop = *++FuncIter;
  BasicBlock &Case1 = *++FuncIter;
  BasicBlock &Case2 = *++FuncIter;
  BasicBlock &BB = *++FuncIter;

  Instruction &PhiLoop = Loop.front();
  Instruction &Switch = Loop.back();
  Instruction &Add = Case1.front();
  Instruction &PhiCase2 = Case2.front();
  Instruction &BrBB = Case2.back();
  Instruction &PhiBB = BB.front();
  Instruction &Icmp = *++BB.begin();
  Instruction &Branch = BB.back();

  // %a and %b on their own are expected to give no immediate bonus, since
  // each is only an incoming value of a PHI with other unknown inputs.
  Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);

  Test = Visitor.getSpecializationBonus(F->getArg(1), One);
  EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);

  // switch + phi + br
  // With %i == 1 only case1 is taken, so case2 is expected at code size only.
  Bonus Ref = getInstCost(Switch) +
              getInstCost(PhiCase2, /*SizeOnly =*/ true) +
              getInstCost(BrBB, /*SizeOnly =*/ true);
  Test = Visitor.getSpecializationBonus(F->getArg(2), One);
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);

  // phi + phi + add + icmp + branch
  // Must run after the calls above: the PHI-dependent instructions are
  // expected to be credited only when the pending PHIs are resolved.
  Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
        getInstCost(Icmp) + getInstCost(Branch);
  Test = Visitor.getBonusFromPendingPHIs();
  EXPECT_EQ(Test, Ref);
  EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
}
397 
398