xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision ced90d1ff64a89a13479a37a3b17a411a3259f9f)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
23 
24 namespace llvm {
25 
26 class FunctionSpecializationTest : public testing::Test {
27 protected:
28   LLVMContext Ctx;
29   FunctionAnalysisManager FAM;
30   std::unique_ptr<Module> M;
31   std::unique_ptr<SCCPSolver> Solver;
32 
33   FunctionSpecializationTest() {
34     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
35     FAM.registerPass([&] { return TargetIRAnalysis(); });
36     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
37     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
38     FAM.registerPass([&] { return LoopAnalysis(); });
39     FAM.registerPass([&] { return AssumptionAnalysis(); });
40     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
41     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
42     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
43   }
44 
45   Module &parseModule(const char *ModuleString) {
46     SMDiagnostic Err;
47     M = parseAssemblyString(ModuleString, Err, Ctx);
48     EXPECT_TRUE(M);
49     return *M;
50   }
51 
52   FunctionSpecializer getSpecializerFor(Function *F) {
53     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
54       return FAM.getResult<TargetLibraryAnalysis>(F);
55     };
56     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
57       return FAM.getResult<TargetIRAnalysis>(F);
58     };
59     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
60       return FAM.getResult<BlockFrequencyAnalysis>(F);
61     };
62     auto GetAC = [this](Function &F) -> AssumptionCache & {
63       return FAM.getResult<AssumptionAnalysis>(F);
64     };
65     auto GetAnalysis = [this](Function &F) -> AnalysisResultsForFn {
66       DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
67       return { std::make_unique<PredicateInfo>(F, DT,
68                                 FAM.getResult<AssumptionAnalysis>(F)),
69                &DT, FAM.getCachedResult<PostDominatorTreeAnalysis>(F) };
70     };
71 
72     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
73 
74     Solver->addAnalysis(*F, GetAnalysis(*F));
75     Solver->markBlockExecutable(&F->front());
76     for (Argument &Arg : F->args())
77       Solver->markOverdefined(&Arg);
78     Solver->solveWhileResolvedUndefsIn(*M);
79 
80     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
81                                GetAC);
82   }
83 
84   Cost getInstCost(Instruction &I) {
85     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
86     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
87 
88     return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
89          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
90   }
91 };
92 
93 } // namespace llvm
94 
95 using namespace llvm;
96 
97 TEST_F(FunctionSpecializationTest, SwitchInst) {
98   const char *ModuleString = R"(
99     define void @foo(i32 %a, i32 %b, i32 %i) {
100     entry:
101       switch i32 %i, label %default
102       [ i32 1, label %case1
103         i32 2, label %case2 ]
104     case1:
105       %0 = mul i32 %a, 2
106       %1 = sub i32 6, 5
107       br label %bb1
108     case2:
109       %2 = and i32 %b, 3
110       %3 = sdiv i32 8, 2
111       br label %bb2
112     bb1:
113       %4 = add i32 %0, %b
114       br label %default
115     bb2:
116       %5 = or i32 %2, %a
117       br label %default
118     default:
119       ret void
120     }
121   )";
122 
123   Module &M = parseModule(ModuleString);
124   Function *F = M.getFunction("foo");
125   FunctionSpecializer Specializer = getSpecializerFor(F);
126   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
127 
128   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
129 
130   auto FuncIter = F->begin();
131   BasicBlock &Case1 = *++FuncIter;
132   BasicBlock &Case2 = *++FuncIter;
133   BasicBlock &BB1 = *++FuncIter;
134   BasicBlock &BB2 = *++FuncIter;
135 
136   Instruction &Mul = Case1.front();
137   Instruction &And = Case2.front();
138   Instruction &Sdiv = *++Case2.begin();
139   Instruction &BrBB2 = Case2.back();
140   Instruction &Add = BB1.front();
141   Instruction &Or = BB2.front();
142   Instruction &BrDefault = BB2.back();
143 
144   // mul
145   Cost Ref = getInstCost(Mul);
146   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
147   EXPECT_EQ(Bonus, Ref);
148 
149   // and + or + add
150   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
151   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
152   EXPECT_EQ(Bonus, Ref);
153 
154   // sdiv + br + br
155   Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrDefault);
156   Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
157   EXPECT_EQ(Bonus, Ref);
158 }
159 
160 TEST_F(FunctionSpecializationTest, BranchInst) {
161   const char *ModuleString = R"(
162     define void @foo(i32 %a, i32 %b, i1 %cond) {
163     entry:
164       br i1 %cond, label %bb0, label %bb2
165     bb0:
166       %0 = mul i32 %a, 2
167       %1 = sub i32 6, 5
168       br label %bb1
169     bb1:
170       %2 = add i32 %0, %b
171       %3 = sdiv i32 8, 2
172       br label %bb2
173     bb2:
174       ret void
175     }
176   )";
177 
178   Module &M = parseModule(ModuleString);
179   Function *F = M.getFunction("foo");
180   FunctionSpecializer Specializer = getSpecializerFor(F);
181   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
182 
183   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
184   Constant *False = ConstantInt::getFalse(M.getContext());
185 
186   auto FuncIter = F->begin();
187   BasicBlock &BB0 = *++FuncIter;
188   BasicBlock &BB1 = *++FuncIter;
189 
190   Instruction &Mul = BB0.front();
191   Instruction &Sub = *++BB0.begin();
192   Instruction &BrBB1 = BB0.back();
193   Instruction &Add = BB1.front();
194   Instruction &Sdiv = *++BB1.begin();
195   Instruction &BrBB2 = BB1.back();
196 
197   // mul
198   Cost Ref = getInstCost(Mul);
199   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
200   EXPECT_EQ(Bonus, Ref);
201 
202   // add
203   Ref = getInstCost(Add);
204   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
205   EXPECT_EQ(Bonus, Ref);
206 
207   // sub + br + sdiv + br
208   Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
209         getInstCost(BrBB2);
210   Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
211   EXPECT_EQ(Bonus, Ref);
212 }
213 
214 TEST_F(FunctionSpecializationTest, Misc) {
215   const char *ModuleString = R"(
216     @g = constant [2 x i32] zeroinitializer, align 4
217 
218     define i32 @foo(i8 %a, i1 %cond, ptr %b) {
219       %cmp = icmp eq i8 %a, 10
220       %ext = zext i1 %cmp to i32
221       %sel = select i1 %cond, i32 %ext, i32 1
222       %gep = getelementptr i32, ptr %b, i32 %sel
223       %ld = load i32, ptr %gep
224       ret i32 %ld
225     }
226   )";
227 
228   Module &M = parseModule(ModuleString);
229   Function *F = M.getFunction("foo");
230   FunctionSpecializer Specializer = getSpecializerFor(F);
231   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
232 
233   GlobalVariable *GV = M.getGlobalVariable("g");
234   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
235   Constant *True = ConstantInt::getTrue(M.getContext());
236 
237   auto BlockIter = F->front().begin();
238   Instruction &Icmp = *BlockIter++;
239   Instruction &Zext = *BlockIter++;
240   Instruction &Select = *BlockIter++;
241   Instruction &Gep = *BlockIter++;
242   Instruction &Load = *BlockIter++;
243 
244   // icmp + zext
245   Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
246   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
247   EXPECT_EQ(Bonus, Ref);
248 
249   // select
250   Ref = getInstCost(Select);
251   Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
252   EXPECT_EQ(Bonus, Ref);
253 
254   // gep + load
255   Ref = getInstCost(Gep) + getInstCost(Load);
256   Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
257   EXPECT_EQ(Bonus, Ref);
258 }
259