xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 5400257ded86b65e5b55cbb138e0438b96c9bebf)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
23 
24 namespace llvm {
25 
26 class FunctionSpecializationTest : public testing::Test {
27 protected:
28   LLVMContext Ctx;
29   FunctionAnalysisManager FAM;
30   std::unique_ptr<Module> M;
31   std::unique_ptr<SCCPSolver> Solver;
32 
33   FunctionSpecializationTest() {
34     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
35     FAM.registerPass([&] { return TargetIRAnalysis(); });
36     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
37     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
38     FAM.registerPass([&] { return LoopAnalysis(); });
39     FAM.registerPass([&] { return AssumptionAnalysis(); });
40     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
41     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
42     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43   }
44 
45   Module &parseModule(const char *ModuleString) {
46     SMDiagnostic Err;
47     M = parseAssemblyString(ModuleString, Err, Ctx);
48     EXPECT_TRUE(M);
49     return *M;
50   }
51 
52   FunctionSpecializer getSpecializerFor(Function *F) {
53     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
54       return FAM.getResult<TargetLibraryAnalysis>(F);
55     };
56     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
57       return FAM.getResult<TargetIRAnalysis>(F);
58     };
59     auto GetAC = [this](Function &F) -> AssumptionCache & {
60       return FAM.getResult<AssumptionAnalysis>(F);
61     };
62     auto GetDT = [this](Function &F) -> DominatorTree & {
63       return FAM.getResult<DominatorTreeAnalysis>(F);
64     };
65     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
66       return FAM.getResult<BlockFrequencyAnalysis>(F);
67     };
68 
69     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
70 
71     DominatorTree &DT = GetDT(*F);
72     AssumptionCache &AC = GetAC(*F);
73     Solver->addPredicateInfo(*F, DT, AC);
74 
75     Solver->markBlockExecutable(&F->front());
76     for (Argument &Arg : F->args())
77       Solver->markOverdefined(&Arg);
78     Solver->solveWhileResolvedUndefsIn(*M);
79 
80     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
81                                GetAC);
82   }
83 
84   Cost getInstCost(Instruction &I) {
85     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
86     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
87 
88     return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
89          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
90   }
91 };
92 
93 } // namespace llvm
94 
95 using namespace llvm;
96 
97 TEST_F(FunctionSpecializationTest, SwitchInst) {
98   const char *ModuleString = R"(
99     define void @foo(i32 %a, i32 %b, i32 %i) {
100     entry:
101       br label %loop
102     loop:
103       switch i32 %i, label %default
104       [ i32 1, label %case1
105         i32 2, label %case2 ]
106     case1:
107       %0 = mul i32 %a, 2
108       %1 = sub i32 6, 5
109       br label %bb1
110     case2:
111       %2 = and i32 %b, 3
112       %3 = sdiv i32 8, 2
113       br label %bb2
114     bb1:
115       %4 = add i32 %0, %b
116       br label %loop
117     bb2:
118       %5 = or i32 %2, %a
119       br label %loop
120     default:
121       ret void
122     }
123   )";
124 
125   Module &M = parseModule(ModuleString);
126   Function *F = M.getFunction("foo");
127   FunctionSpecializer Specializer = getSpecializerFor(F);
128   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
129 
130   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
131 
132   auto FuncIter = F->begin();
133   ++FuncIter;
134   BasicBlock &Case1 = *++FuncIter;
135   BasicBlock &Case2 = *++FuncIter;
136   BasicBlock &BB1 = *++FuncIter;
137   BasicBlock &BB2 = *++FuncIter;
138 
139   Instruction &Mul = Case1.front();
140   Instruction &And = Case2.front();
141   Instruction &Sdiv = *++Case2.begin();
142   Instruction &BrBB2 = Case2.back();
143   Instruction &Add = BB1.front();
144   Instruction &Or = BB2.front();
145   Instruction &BrLoop = BB2.back();
146 
147   // mul
148   Cost Ref = getInstCost(Mul);
149   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
150   EXPECT_EQ(Bonus, Ref);
151   EXPECT_TRUE(Bonus > 0);
152 
153   // and + or + add
154   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
155   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
156   EXPECT_EQ(Bonus, Ref);
157   EXPECT_TRUE(Bonus > 0);
158 
159   // sdiv + br + br
160   Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrLoop);
161   Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
162   EXPECT_EQ(Bonus, Ref);
163   EXPECT_TRUE(Bonus > 0);
164 }
165 
166 TEST_F(FunctionSpecializationTest, BranchInst) {
167   const char *ModuleString = R"(
168     define void @foo(i32 %a, i32 %b, i1 %cond) {
169     entry:
170       br label %loop
171     loop:
172       br i1 %cond, label %bb0, label %bb2
173     bb0:
174       %0 = mul i32 %a, 2
175       %1 = sub i32 6, 5
176       br label %bb1
177     bb1:
178       %2 = add i32 %0, %b
179       %3 = sdiv i32 8, 2
180       br label %loop
181     bb2:
182       ret void
183     }
184   )";
185 
186   Module &M = parseModule(ModuleString);
187   Function *F = M.getFunction("foo");
188   FunctionSpecializer Specializer = getSpecializerFor(F);
189   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
190 
191   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
192   Constant *False = ConstantInt::getFalse(M.getContext());
193 
194   auto FuncIter = F->begin();
195   ++FuncIter;
196   BasicBlock &BB0 = *++FuncIter;
197   BasicBlock &BB1 = *++FuncIter;
198 
199   Instruction &Mul = BB0.front();
200   Instruction &Sub = *++BB0.begin();
201   Instruction &BrBB1 = BB0.back();
202   Instruction &Add = BB1.front();
203   Instruction &Sdiv = *++BB1.begin();
204   Instruction &BrLoop = BB1.back();
205 
206   // mul
207   Cost Ref = getInstCost(Mul);
208   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
209   EXPECT_EQ(Bonus, Ref);
210   EXPECT_TRUE(Bonus > 0);
211 
212   // add
213   Ref = getInstCost(Add);
214   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
215   EXPECT_EQ(Bonus, Ref);
216   EXPECT_TRUE(Bonus > 0);
217 
218   // sub + br + sdiv + br
219   Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
220         getInstCost(BrLoop);
221   Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
222   EXPECT_EQ(Bonus, Ref);
223   EXPECT_TRUE(Bonus > 0);
224 }
225 
226 TEST_F(FunctionSpecializationTest, Misc) {
227   const char *ModuleString = R"(
228     @g = constant [2 x i32] zeroinitializer, align 4
229 
230     declare i32 @llvm.smax.i32(i32, i32)
231     declare i32 @bar(i32)
232 
233     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
234       %cmp = icmp eq i8 %a, 10
235       %ext = zext i1 %cmp to i32
236       %sel = select i1 %cond, i32 %ext, i32 1
237       %gep = getelementptr i32, ptr %b, i32 %sel
238       %ld = load i32, ptr %gep
239       %fr = freeze i32 %ld
240       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
241       %call = call i32 @bar(i32 %smax)
242       %fr2 = freeze i32 %c
243       %add = add i32 %call, %fr2
244       ret i32 %add
245     }
246   )";
247 
248   Module &M = parseModule(ModuleString);
249   Function *F = M.getFunction("foo");
250   FunctionSpecializer Specializer = getSpecializerFor(F);
251   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
252 
253   GlobalVariable *GV = M.getGlobalVariable("g");
254   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
255   Constant *True = ConstantInt::getTrue(M.getContext());
256   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
257 
258   auto BlockIter = F->front().begin();
259   Instruction &Icmp = *BlockIter++;
260   Instruction &Zext = *BlockIter++;
261   Instruction &Select = *BlockIter++;
262   Instruction &Gep = *BlockIter++;
263   Instruction &Load = *BlockIter++;
264   Instruction &Freeze = *BlockIter++;
265   Instruction &Smax = *BlockIter++;
266 
267   // icmp + zext
268   Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
269   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
270   EXPECT_EQ(Bonus, Ref);
271   EXPECT_TRUE(Bonus > 0);
272 
273   // select
274   Ref = getInstCost(Select);
275   Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
276   EXPECT_EQ(Bonus, Ref);
277   EXPECT_TRUE(Bonus > 0);
278 
279   // gep + load + freeze + smax
280   Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
281         getInstCost(Smax);
282   Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
283   EXPECT_EQ(Bonus, Ref);
284   EXPECT_TRUE(Bonus > 0);
285 
286   Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
287   EXPECT_TRUE(Bonus == 0);
288 }
289