xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 5bfefff1c44fd992b673e1ff9c9f1865f9d81af1)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
23 
24 namespace llvm {
25 
26 static void removeSSACopy(Function &F) {
27   for (BasicBlock &BB : F) {
28     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
29       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
30         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
31           continue;
32         Inst.replaceAllUsesWith(II->getOperand(0));
33         Inst.eraseFromParent();
34       }
35     }
36   }
37 }
38 
39 class FunctionSpecializationTest : public testing::Test {
40 protected:
41   LLVMContext Ctx;
42   FunctionAnalysisManager FAM;
43   std::unique_ptr<Module> M;
44   std::unique_ptr<SCCPSolver> Solver;
45 
46   FunctionSpecializationTest() {
47     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
48     FAM.registerPass([&] { return TargetIRAnalysis(); });
49     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
50     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
51     FAM.registerPass([&] { return LoopAnalysis(); });
52     FAM.registerPass([&] { return AssumptionAnalysis(); });
53     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
54     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
55     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
56   }
57 
58   Module &parseModule(const char *ModuleString) {
59     SMDiagnostic Err;
60     M = parseAssemblyString(ModuleString, Err, Ctx);
61     EXPECT_TRUE(M);
62     return *M;
63   }
64 
65   FunctionSpecializer getSpecializerFor(Function *F) {
66     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
67       return FAM.getResult<TargetLibraryAnalysis>(F);
68     };
69     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
70       return FAM.getResult<TargetIRAnalysis>(F);
71     };
72     auto GetAC = [this](Function &F) -> AssumptionCache & {
73       return FAM.getResult<AssumptionAnalysis>(F);
74     };
75     auto GetDT = [this](Function &F) -> DominatorTree & {
76       return FAM.getResult<DominatorTreeAnalysis>(F);
77     };
78     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
79       return FAM.getResult<BlockFrequencyAnalysis>(F);
80     };
81 
82     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
83 
84     DominatorTree &DT = GetDT(*F);
85     AssumptionCache &AC = GetAC(*F);
86     Solver->addPredicateInfo(*F, DT, AC);
87 
88     Solver->markBlockExecutable(&F->front());
89     for (Argument &Arg : F->args())
90       Solver->markOverdefined(&Arg);
91     Solver->solveWhileResolvedUndefsIn(*M);
92 
93     removeSSACopy(*F);
94 
95     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
96                                GetAC);
97   }
98 
99   Bonus getInstCost(Instruction &I, bool SizeOnly = false) {
100     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
101     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
102 
103     Cost CodeSize =
104         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
105 
106     Cost Latency = SizeOnly ? 0 :
107         BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
108         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_Latency);
109 
110     return {CodeSize, Latency};
111   }
112 };
113 
114 } // namespace llvm
115 
116 using namespace llvm;
117 
118 TEST_F(FunctionSpecializationTest, SwitchInst) {
119   const char *ModuleString = R"(
120     define void @foo(i32 %a, i32 %b, i32 %i) {
121     entry:
122       br label %loop
123     loop:
124       switch i32 %i, label %default
125       [ i32 1, label %case1
126         i32 2, label %case2 ]
127     case1:
128       %0 = mul i32 %a, 2
129       %1 = sub i32 6, 5
130       br label %bb1
131     case2:
132       %2 = and i32 %b, 3
133       %3 = sdiv i32 8, 2
134       br label %bb2
135     bb1:
136       %4 = add i32 %0, %b
137       br label %loop
138     bb2:
139       %5 = or i32 %2, %a
140       br label %loop
141     default:
142       ret void
143     }
144   )";
145 
146   Module &M = parseModule(ModuleString);
147   Function *F = M.getFunction("foo");
148   FunctionSpecializer Specializer = getSpecializerFor(F);
149   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
150 
151   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
152 
153   auto FuncIter = F->begin();
154   BasicBlock &Loop = *++FuncIter;
155   BasicBlock &Case1 = *++FuncIter;
156   BasicBlock &Case2 = *++FuncIter;
157   BasicBlock &BB1 = *++FuncIter;
158   BasicBlock &BB2 = *++FuncIter;
159 
160   Instruction &Switch = Loop.front();
161   Instruction &Mul = Case1.front();
162   Instruction &And = Case2.front();
163   Instruction &Sdiv = *++Case2.begin();
164   Instruction &BrBB2 = Case2.back();
165   Instruction &Add = BB1.front();
166   Instruction &Or = BB2.front();
167   Instruction &BrLoop = BB2.back();
168 
169   // mul
170   Bonus Ref = getInstCost(Mul);
171   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
172   EXPECT_EQ(Test, Ref);
173   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
174 
175   // and + or + add
176   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
177   Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
178   EXPECT_EQ(Test, Ref);
179   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
180 
181   // switch + sdiv + br + br
182   Ref = getInstCost(Switch) +
183         getInstCost(Sdiv, /*SizeOnly =*/ true) +
184         getInstCost(BrBB2, /*SizeOnly =*/ true) +
185         getInstCost(BrLoop, /*SizeOnly =*/ true);
186   Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
187   EXPECT_EQ(Test, Ref);
188   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
189 }
190 
191 TEST_F(FunctionSpecializationTest, BranchInst) {
192   const char *ModuleString = R"(
193     define void @foo(i32 %a, i32 %b, i1 %cond) {
194     entry:
195       br label %loop
196     loop:
197       br i1 %cond, label %bb0, label %bb2
198     bb0:
199       %0 = mul i32 %a, 2
200       %1 = sub i32 6, 5
201       br label %bb1
202     bb1:
203       %2 = add i32 %0, %b
204       %3 = sdiv i32 8, 2
205       br label %loop
206     bb2:
207       ret void
208     }
209   )";
210 
211   Module &M = parseModule(ModuleString);
212   Function *F = M.getFunction("foo");
213   FunctionSpecializer Specializer = getSpecializerFor(F);
214   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
215 
216   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
217   Constant *False = ConstantInt::getFalse(M.getContext());
218 
219   auto FuncIter = F->begin();
220   BasicBlock &Loop = *++FuncIter;
221   BasicBlock &BB0 = *++FuncIter;
222   BasicBlock &BB1 = *++FuncIter;
223 
224   Instruction &Branch = Loop.front();
225   Instruction &Mul = BB0.front();
226   Instruction &Sub = *++BB0.begin();
227   Instruction &BrBB1 = BB0.back();
228   Instruction &Add = BB1.front();
229   Instruction &Sdiv = *++BB1.begin();
230   Instruction &BrLoop = BB1.back();
231 
232   // mul
233   Bonus Ref = getInstCost(Mul);
234   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
235   EXPECT_EQ(Test, Ref);
236   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
237 
238   // add
239   Ref = getInstCost(Add);
240   Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
241   EXPECT_EQ(Test, Ref);
242   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
243 
244   // branch + sub + br + sdiv + br
245   Ref = getInstCost(Branch) +
246         getInstCost(Sub, /*SizeOnly =*/ true) +
247         getInstCost(BrBB1, /*SizeOnly =*/ true) +
248         getInstCost(Sdiv, /*SizeOnly =*/ true) +
249         getInstCost(BrLoop, /*SizeOnly =*/ true);
250   Test = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
251   EXPECT_EQ(Test, Ref);
252   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
253 }
254 
255 TEST_F(FunctionSpecializationTest, Misc) {
256   const char *ModuleString = R"(
257     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
258     @g = constant %struct_t zeroinitializer, align 16
259 
260     declare i32 @llvm.smax.i32(i32, i32)
261     declare i32 @bar(i32)
262 
263     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
264       %cmp = icmp eq i8 %a, 10
265       %ext = zext i1 %cmp to i64
266       %sel = select i1 %cond, i64 %ext, i64 1
267       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
268       %ld = load i32, ptr %gep
269       %fr = freeze i32 %ld
270       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
271       %call = call i32 @bar(i32 %smax)
272       %fr2 = freeze i32 %c
273       %add = add i32 %call, %fr2
274       ret i32 %add
275     }
276   )";
277 
278   Module &M = parseModule(ModuleString);
279   Function *F = M.getFunction("foo");
280   FunctionSpecializer Specializer = getSpecializerFor(F);
281   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
282 
283   GlobalVariable *GV = M.getGlobalVariable("g");
284   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
285   Constant *True = ConstantInt::getTrue(M.getContext());
286   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
287 
288   auto BlockIter = F->front().begin();
289   Instruction &Icmp = *BlockIter++;
290   Instruction &Zext = *BlockIter++;
291   Instruction &Select = *BlockIter++;
292   Instruction &Gep = *BlockIter++;
293   Instruction &Load = *BlockIter++;
294   Instruction &Freeze = *BlockIter++;
295   Instruction &Smax = *BlockIter++;
296 
297   // icmp + zext
298   Bonus Ref = getInstCost(Icmp) + getInstCost(Zext);
299   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
300   EXPECT_EQ(Test, Ref);
301   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
302 
303   // select
304   Ref = getInstCost(Select);
305   Test = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
306   EXPECT_EQ(Test, Ref);
307   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
308 
309   // gep + load + freeze + smax
310   Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
311         getInstCost(Smax);
312   Test = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
313   EXPECT_EQ(Test, Ref);
314   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
315 
316   Test = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
317   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
318 }
319 
320 TEST_F(FunctionSpecializationTest, PhiNode) {
321   const char *ModuleString = R"(
322     define void @foo(i32 %a, i32 %b, i32 %i) {
323     entry:
324       br label %loop
325     loop:
326       %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
327       switch i32 %i, label %default
328       [ i32 1, label %case1
329         i32 2, label %case2 ]
330     case1:
331       %1 = add i32 %0, 1
332       br label %bb
333     case2:
334       %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
335       br label %bb
336     bb:
337       %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
338       %4 = icmp eq i32 %3, 1
339       br i1 %4, label %bb, label %loop
340     default:
341       ret void
342     }
343   )";
344 
345   Module &M = parseModule(ModuleString);
346   Function *F = M.getFunction("foo");
347   FunctionSpecializer Specializer = getSpecializerFor(F);
348   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
349 
350   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
351 
352   auto FuncIter = F->begin();
353   BasicBlock &Loop = *++FuncIter;
354   BasicBlock &Case1 = *++FuncIter;
355   BasicBlock &Case2 = *++FuncIter;
356   BasicBlock &BB = *++FuncIter;
357 
358   Instruction &PhiLoop = Loop.front();
359   Instruction &Switch = Loop.back();
360   Instruction &Add = Case1.front();
361   Instruction &PhiCase2 = Case2.front();
362   Instruction &BrBB = Case2.back();
363   Instruction &PhiBB = BB.front();
364   Instruction &Icmp = *++BB.begin();
365   Instruction &Branch = BB.back();
366 
367   Bonus Test = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
368   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
369 
370   Test = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
371   EXPECT_TRUE(Test.CodeSize == 0 && Test.Latency == 0);
372 
373   // switch + phi + br
374   Bonus Ref = getInstCost(Switch) +
375               getInstCost(PhiCase2, /*SizeOnly =*/ true) +
376               getInstCost(BrBB, /*SizeOnly =*/ true);
377   Test = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
378   EXPECT_EQ(Test, Ref);
379   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
380 
381   // phi + phi + add + icmp + branch
382   Ref = getInstCost(PhiBB) + getInstCost(PhiLoop) + getInstCost(Add) +
383         getInstCost(Icmp) + getInstCost(Branch);
384   Test = Visitor.getBonusFromPendingPHIs();
385   EXPECT_EQ(Test, Ref);
386   EXPECT_TRUE(Test.CodeSize > 0 && Test.Latency > 0);
387 }
388 
389