xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 5181156b3743df29dc840e15990d9202b3501f60)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
23 
24 namespace llvm {
25 
26 static void removeSSACopy(Function &F) {
27   for (BasicBlock &BB : F) {
28     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
29       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
30         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
31           continue;
32         Inst.replaceAllUsesWith(II->getOperand(0));
33         Inst.eraseFromParent();
34       }
35     }
36   }
37 }
38 
39 class FunctionSpecializationTest : public testing::Test {
40 protected:
41   LLVMContext Ctx;
42   FunctionAnalysisManager FAM;
43   std::unique_ptr<Module> M;
44   std::unique_ptr<SCCPSolver> Solver;
45 
46   FunctionSpecializationTest() {
47     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
48     FAM.registerPass([&] { return TargetIRAnalysis(); });
49     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
50     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
51     FAM.registerPass([&] { return LoopAnalysis(); });
52     FAM.registerPass([&] { return AssumptionAnalysis(); });
53     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
54     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
55     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
56   }
57 
58   Module &parseModule(const char *ModuleString) {
59     SMDiagnostic Err;
60     M = parseAssemblyString(ModuleString, Err, Ctx);
61     EXPECT_TRUE(M);
62     return *M;
63   }
64 
65   FunctionSpecializer getSpecializerFor(Function *F) {
66     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
67       return FAM.getResult<TargetLibraryAnalysis>(F);
68     };
69     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
70       return FAM.getResult<TargetIRAnalysis>(F);
71     };
72     auto GetAC = [this](Function &F) -> AssumptionCache & {
73       return FAM.getResult<AssumptionAnalysis>(F);
74     };
75     auto GetDT = [this](Function &F) -> DominatorTree & {
76       return FAM.getResult<DominatorTreeAnalysis>(F);
77     };
78     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
79       return FAM.getResult<BlockFrequencyAnalysis>(F);
80     };
81 
82     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
83 
84     DominatorTree &DT = GetDT(*F);
85     AssumptionCache &AC = GetAC(*F);
86     Solver->addPredicateInfo(*F, DT, AC);
87 
88     Solver->markBlockExecutable(&F->front());
89     for (Argument &Arg : F->args())
90       Solver->markOverdefined(&Arg);
91     Solver->solveWhileResolvedUndefsIn(*M);
92 
93     removeSSACopy(*F);
94 
95     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
96                                GetAC);
97   }
98 
99   Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
100     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
101     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
102 
103     Cost CodeSize =
104         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
105 
106     Cost Latency =
107         SizeOnly
108             ? 0
109             : BFI.getBlockFreq(I.getParent()).getFrequency() /
110                   BFI.getEntryFreq().getFrequency() *
111                   TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
112 
113     return {CodeSize, Latency};
114   }
115 };
116 
117 } // namespace llvm
118 
119 using namespace llvm;
120 
121 TEST_F(FunctionSpecializationTest, SwitchInst) {
122   const char *ModuleString = R"(
123     define void @foo(i32 %a, i32 %b, i32 %i) {
124     entry:
125       br label %loop
126     loop:
127       switch i32 %i, label %default
128       [ i32 1, label %case1
129         i32 2, label %case2 ]
130     case1:
131       %0 = mul i32 %a, 2
132       %1 = sub i32 6, 5
133       br label %bb1
134     case2:
135       %2 = and i32 %b, 3
136       %3 = sdiv i32 8, 2
137       br label %bb2
138     bb1:
139       %4 = add i32 %0, %b
140       br label %loop
141     bb2:
142       %5 = or i32 %2, %a
143       br label %loop
144     default:
145       ret void
146     }
147   )";
148 
149   Module &M = parseModule(ModuleString);
150   Function *F = M.getFunction("foo");
151   FunctionSpecializer Specializer = getSpecializerFor(F);
152   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
153 
154   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
155 
156   auto FuncIter = F->begin();
157   BasicBlock &Loop = *++FuncIter;
158   BasicBlock &Case1 = *++FuncIter;
159   BasicBlock &Case2 = *++FuncIter;
160   BasicBlock &BB1 = *++FuncIter;
161   BasicBlock &BB2 = *++FuncIter;
162 
163   Instruction &Switch = Loop.front();
164   Instruction &Mul = Case1.front();
165   Instruction &And = Case2.front();
166   Instruction &Sdiv = *++Case2.begin();
167   Instruction &BrBB2 = Case2.back();
168   Instruction &Add = BB1.front();
169   Instruction &Or = BB2.front();
170   Instruction &BrLoop = BB2.back();
171 
172   // mul
173   Bonus Ref = getInstCost(Mul);
174   Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
175   EXPECT_EQ(Test, Ref);
176   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
177 
178   // and + or + add
179   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
180   Test = Visitor.getSpecializationBonus(F->getArg(1), One);
181   EXPECT_EQ(Test, Ref);
182   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
183 
184   // switch + sdiv + br + br
185   Ref = getInstCost(Switch) +
186         getInstCost(Sdiv, /*SizeOnly =*/ true) +
187         getInstCost(BrBB2, /*SizeOnly =*/ true) +
188         getInstCost(BrLoop, /*SizeOnly =*/ true);
189   Test = Visitor.getSpecializationBonus(F->getArg(2), One);
190   EXPECT_EQ(Test, Ref);
191   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
192 }
193 
194 TEST_F(FunctionSpecializationTest, BranchInst) {
195   const char *ModuleString = R"(
196     define void @foo(i32 %a, i32 %b, i1 %cond) {
197     entry:
198       br label %loop
199     loop:
200       br i1 %cond, label %bb0, label %bb3
201     bb0:
202       %0 = mul i32 %a, 2
203       %1 = sub i32 6, 5
204       br i1 %cond, label %bb1, label %bb2
205     bb1:
206       %2 = add i32 %0, %b
207       %3 = sdiv i32 8, 2
208       br label %bb2
209     bb2:
210       br label %loop
211     bb3:
212       ret void
213     }
214   )";
215 
216   Module &M = parseModule(ModuleString);
217   Function *F = M.getFunction("foo");
218   FunctionSpecializer Specializer = getSpecializerFor(F);
219   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
220 
221   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
222   Constant *False = ConstantInt::getFalse(M.getContext());
223 
224   auto FuncIter = F->begin();
225   BasicBlock &Loop = *++FuncIter;
226   BasicBlock &BB0 = *++FuncIter;
227   BasicBlock &BB1 = *++FuncIter;
228   BasicBlock &BB2 = *++FuncIter;
229 
230   Instruction &Branch = Loop.front();
231   Instruction &Mul = BB0.front();
232   Instruction &Sub = *++BB0.begin();
233   Instruction &BrBB1BB2 = BB0.back();
234   Instruction &Add = BB1.front();
235   Instruction &Sdiv = *++BB1.begin();
236   Instruction &BrBB2 = BB1.back();
237   Instruction &BrLoop = BB2.front();
238 
239   // mul
240   Bonus Ref = getInstCost(Mul);
241   Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
242   EXPECT_EQ(Test, Ref);
243   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
244 
245   // add
246   Ref = getInstCost(Add);
247   Test = Visitor.getSpecializationBonus(F->getArg(1), One);
248   EXPECT_EQ(Test, Ref);
249   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
250 
251   // branch + sub + br + sdiv + br
252   Ref = getInstCost(Branch) +
253         getInstCost(Sub, /*SizeOnly =*/ true) +
254         getInstCost(BrBB1BB2) +
255         getInstCost(Sdiv, /*SizeOnly =*/ true) +
256         getInstCost(BrBB2, /*SizeOnly =*/ true) +
257         getInstCost(BrLoop, /*SizeOnly =*/ true);
258   Test = Visitor.getSpecializationBonus(F->getArg(2), False);
259   EXPECT_EQ(Test, Ref);
260   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
261 }
262 
263 TEST_F(FunctionSpecializationTest, Misc) {
264   const char *ModuleString = R"(
265     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
266     @g = constant %struct_t zeroinitializer, align 16
267 
268     declare i32 @llvm.smax.i32(i32, i32)
269     declare i32 @bar(i32)
270 
271     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
272       %cmp = icmp eq i8 %a, 10
273       %ext = zext i1 %cmp to i64
274       %sel = select i1 %cond, i64 %ext, i64 1
275       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
276       %ld = load i32, ptr %gep
277       %fr = freeze i32 %ld
278       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
279       %call = call i32 @bar(i32 %smax)
280       %fr2 = freeze i32 %c
281       %add = add i32 %call, %fr2
282       ret i32 %add
283     }
284   )";
285 
286   Module &M = parseModule(ModuleString);
287   Function *F = M.getFunction("foo");
288   FunctionSpecializer Specializer = getSpecializerFor(F);
289   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
290 
291   GlobalVariable *GV = M.getGlobalVariable("g");
292   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
293   Constant *True = ConstantInt::getTrue(M.getContext());
294   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
295 
296   auto BlockIter = F->front().begin();
297   Instruction &Icmp = *BlockIter++;
298   Instruction &Zext = *BlockIter++;
299   Instruction &Select = *BlockIter++;
300   Instruction &Gep = *BlockIter++;
301   Instruction &Load = *BlockIter++;
302   Instruction &Freeze = *BlockIter++;
303   Instruction &Smax = *BlockIter++;
304 
305   // icmp + zext
306   Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
307   Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
308   EXPECT_EQ(Test, Ref);
309   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
310 
311   // select
312   Ref = getInstCost(Select);
313   Test = Visitor.getSpecializationBonus(F->getArg(1), True);
314   EXPECT_EQ(Test, Ref);
315   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
316 
317   // gep + load + freeze + smax
318   Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
319         getInstCost(Smax);
320   Test = Visitor.getSpecializationBonus(F->getArg(2), GV);
321   EXPECT_EQ(Test, Ref);
322   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
323 
324   Test = Visitor.getSpecializationBonus(F->getArg(3), Undef);
325   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
326 }
327 
328 TEST_F(FunctionSpecializationTest, PhiNode) {
329   const char *ModuleString = R"(
330     define void @foo(i32 %a, i32 %b, i32 %i) {
331     entry:
332       br label %loop
333     loop:
334       %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
335       switch i32 %i, label %default
336       [ i32 1, label %case1
337         i32 2, label %case2 ]
338     case1:
339       %1 = add i32 %0, 1
340       br label %bb
341     case2:
342       %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
343       br label %bb
344     bb:
345       %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
346       %4 = icmp eq i32 %3, 1
347       br i1 %4, label %bb, label %loop
348     default:
349       ret void
350     }
351   )";
352 
353   Module &M = parseModule(ModuleString);
354   Function *F = M.getFunction("foo");
355   FunctionSpecializer Specializer = getSpecializerFor(F);
356   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
357 
358   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
359 
360   auto FuncIter = F->begin();
361   BasicBlock &Loop = *++FuncIter;
362   BasicBlock &Case1 = *++FuncIter;
363   BasicBlock &Case2 = *++FuncIter;
364   BasicBlock &BB = *++FuncIter;
365 
366   Instruction &PhiLoop = Loop.front();
367   Instruction &Switch = Loop.back();
368   Instruction &Add = Case1.front();
369   Instruction &PhiCase2 = Case2.front();
370   Instruction &BrBB = Case2.back();
371   Instruction &PhiBB = BB.front();
372   Instruction &Icmp = *++BB.begin();
373   Instruction &Branch = BB.back();
374 
375   Bonus Test = Visitor.getSpecializationBonus(F->getArg(0), One);
376   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
377 
378   Test = Visitor.getSpecializationBonus(F->getArg(1), One);
379   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
380 
381   // switch + phi + br
382   Bonus Ref = getInstCost(Switch) +
383               getInstCost(PhiCase2, /*SizeOnly =*/ true) +
384               getInstCost(BrBB, /*SizeOnly =*/ true);
385   Test = Visitor.getSpecializationBonus(F->getArg(2), One);
386   EXPECT_EQ(Test, Ref);
387   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
388 
389   // phi + phi + add + icmp + branch
390   Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
391         getInstCost(Icmp) + getInstCost(Branch);
392   Test = Visitor.getBonusFromPendingPHIs();
393   EXPECT_EQ(Test, Ref);
394   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
395 }
396 
397