xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 32683b231e09fca7d1ee5f5d81627edbcbe59213)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/Support/SourceMgr.h"
19 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
20 #include "llvm/Transforms/Utils/SCCPSolver.h"
21 #include "gtest/gtest.h"
22 #include <memory>
23 
24 namespace llvm {
25 
26 static void removeSSACopy(Function &F) {
27   for (BasicBlock &BB : F) {
28     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
29       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
30         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
31           continue;
32         Inst.replaceAllUsesWith(II->getOperand(0));
33         Inst.eraseFromParent();
34       }
35     }
36   }
37 }
38 
39 class FunctionSpecializationTest : public testing::Test {
40 protected:
41   LLVMContext Ctx;
42   FunctionAnalysisManager FAM;
43   std::unique_ptr<Module> M;
44   std::unique_ptr<SCCPSolver> Solver;
45 
46   FunctionSpecializationTest() {
47     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
48     FAM.registerPass([&] { return TargetIRAnalysis(); });
49     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
50     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
51     FAM.registerPass([&] { return LoopAnalysis(); });
52     FAM.registerPass([&] { return AssumptionAnalysis(); });
53     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
54     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
55     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
56   }
57 
58   Module &parseModule(const char *ModuleString) {
59     SMDiagnostic Err;
60     M = parseAssemblyString(ModuleString, Err, Ctx);
61     EXPECT_TRUE(M);
62     return *M;
63   }
64 
65   FunctionSpecializer getSpecializerFor(Function *F) {
66     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
67       return FAM.getResult<TargetLibraryAnalysis>(F);
68     };
69     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
70       return FAM.getResult<TargetIRAnalysis>(F);
71     };
72     auto GetAC = [this](Function &F) -> AssumptionCache & {
73       return FAM.getResult<AssumptionAnalysis>(F);
74     };
75     auto GetDT = [this](Function &F) -> DominatorTree & {
76       return FAM.getResult<DominatorTreeAnalysis>(F);
77     };
78     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
79       return FAM.getResult<BlockFrequencyAnalysis>(F);
80     };
81 
82     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
83 
84     DominatorTree &DT = GetDT(*F);
85     AssumptionCache &AC = GetAC(*F);
86     Solver->addPredicateInfo(*F, DT, AC);
87 
88     Solver->markBlockExecutable(&F->front());
89     for (Argument &Arg : F->args())
90       Solver->markOverdefined(&Arg);
91     Solver->solveWhileResolvedUndefsIn(*M);
92 
93     removeSSACopy(*F);
94 
95     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
96                                GetAC);
97   }
98 
99   Cost getInstCost(Instruction &I) {
100     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
101     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*I.getFunction());
102 
103     return BFI.getBlockFreq(I.getParent()).getFrequency() / BFI.getEntryFreq() *
104          TTI.getInstructionCost(&I, TargetTransformInfo::TCK_SizeAndLatency);
105   }
106 };
107 
108 } // namespace llvm
109 
110 using namespace llvm;
111 
112 TEST_F(FunctionSpecializationTest, SwitchInst) {
113   const char *ModuleString = R"(
114     define void @foo(i32 %a, i32 %b, i32 %i) {
115     entry:
116       br label %loop
117     loop:
118       switch i32 %i, label %default
119       [ i32 1, label %case1
120         i32 2, label %case2 ]
121     case1:
122       %0 = mul i32 %a, 2
123       %1 = sub i32 6, 5
124       br label %bb1
125     case2:
126       %2 = and i32 %b, 3
127       %3 = sdiv i32 8, 2
128       br label %bb2
129     bb1:
130       %4 = add i32 %0, %b
131       br label %loop
132     bb2:
133       %5 = or i32 %2, %a
134       br label %loop
135     default:
136       ret void
137     }
138   )";
139 
140   Module &M = parseModule(ModuleString);
141   Function *F = M.getFunction("foo");
142   FunctionSpecializer Specializer = getSpecializerFor(F);
143   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
144 
145   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
146 
147   auto FuncIter = F->begin();
148   ++FuncIter;
149   BasicBlock &Case1 = *++FuncIter;
150   BasicBlock &Case2 = *++FuncIter;
151   BasicBlock &BB1 = *++FuncIter;
152   BasicBlock &BB2 = *++FuncIter;
153 
154   Instruction &Mul = Case1.front();
155   Instruction &And = Case2.front();
156   Instruction &Sdiv = *++Case2.begin();
157   Instruction &BrBB2 = Case2.back();
158   Instruction &Add = BB1.front();
159   Instruction &Or = BB2.front();
160   Instruction &BrLoop = BB2.back();
161 
162   // mul
163   Cost Ref = getInstCost(Mul);
164   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
165   EXPECT_EQ(Bonus, Ref);
166   EXPECT_TRUE(Bonus > 0);
167 
168   // and + or + add
169   Ref = getInstCost(And) + getInstCost(Or) + getInstCost(Add);
170   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
171   EXPECT_EQ(Bonus, Ref);
172   EXPECT_TRUE(Bonus > 0);
173 
174   // sdiv + br + br
175   Ref = getInstCost(Sdiv) + getInstCost(BrBB2) + getInstCost(BrLoop);
176   Bonus = Specializer.getSpecializationBonus(F->getArg(2), One, Visitor);
177   EXPECT_EQ(Bonus, Ref);
178   EXPECT_TRUE(Bonus > 0);
179 }
180 
181 TEST_F(FunctionSpecializationTest, BranchInst) {
182   const char *ModuleString = R"(
183     define void @foo(i32 %a, i32 %b, i1 %cond) {
184     entry:
185       br label %loop
186     loop:
187       br i1 %cond, label %bb0, label %bb2
188     bb0:
189       %0 = mul i32 %a, 2
190       %1 = sub i32 6, 5
191       br label %bb1
192     bb1:
193       %2 = add i32 %0, %b
194       %3 = sdiv i32 8, 2
195       br label %loop
196     bb2:
197       ret void
198     }
199   )";
200 
201   Module &M = parseModule(ModuleString);
202   Function *F = M.getFunction("foo");
203   FunctionSpecializer Specializer = getSpecializerFor(F);
204   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
205 
206   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
207   Constant *False = ConstantInt::getFalse(M.getContext());
208 
209   auto FuncIter = F->begin();
210   ++FuncIter;
211   BasicBlock &BB0 = *++FuncIter;
212   BasicBlock &BB1 = *++FuncIter;
213 
214   Instruction &Mul = BB0.front();
215   Instruction &Sub = *++BB0.begin();
216   Instruction &BrBB1 = BB0.back();
217   Instruction &Add = BB1.front();
218   Instruction &Sdiv = *++BB1.begin();
219   Instruction &BrLoop = BB1.back();
220 
221   // mul
222   Cost Ref = getInstCost(Mul);
223   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
224   EXPECT_EQ(Bonus, Ref);
225   EXPECT_TRUE(Bonus > 0);
226 
227   // add
228   Ref = getInstCost(Add);
229   Bonus = Specializer.getSpecializationBonus(F->getArg(1), One, Visitor);
230   EXPECT_EQ(Bonus, Ref);
231   EXPECT_TRUE(Bonus > 0);
232 
233   // sub + br + sdiv + br
234   Ref = getInstCost(Sub) + getInstCost(BrBB1) + getInstCost(Sdiv) +
235         getInstCost(BrLoop);
236   Bonus = Specializer.getSpecializationBonus(F->getArg(2), False, Visitor);
237   EXPECT_EQ(Bonus, Ref);
238   EXPECT_TRUE(Bonus > 0);
239 }
240 
241 TEST_F(FunctionSpecializationTest, Misc) {
242   const char *ModuleString = R"(
243     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
244     @g = constant %struct_t zeroinitializer, align 16
245 
246     declare i32 @llvm.smax.i32(i32, i32)
247     declare i32 @bar(i32)
248 
249     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
250       %cmp = icmp eq i8 %a, 10
251       %ext = zext i1 %cmp to i64
252       %sel = select i1 %cond, i64 %ext, i64 1
253       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
254       %ld = load i32, ptr %gep
255       %fr = freeze i32 %ld
256       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
257       %call = call i32 @bar(i32 %smax)
258       %fr2 = freeze i32 %c
259       %add = add i32 %call, %fr2
260       ret i32 %add
261     }
262   )";
263 
264   Module &M = parseModule(ModuleString);
265   Function *F = M.getFunction("foo");
266   FunctionSpecializer Specializer = getSpecializerFor(F);
267   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
268 
269   GlobalVariable *GV = M.getGlobalVariable("g");
270   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
271   Constant *True = ConstantInt::getTrue(M.getContext());
272   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
273 
274   auto BlockIter = F->front().begin();
275   Instruction &Icmp = *BlockIter++;
276   Instruction &Zext = *BlockIter++;
277   Instruction &Select = *BlockIter++;
278   Instruction &Gep = *BlockIter++;
279   Instruction &Load = *BlockIter++;
280   Instruction &Freeze = *BlockIter++;
281   Instruction &Smax = *BlockIter++;
282 
283   // icmp + zext
284   Cost Ref = getInstCost(Icmp) + getInstCost(Zext);
285   Cost Bonus = Specializer.getSpecializationBonus(F->getArg(0), One, Visitor);
286   EXPECT_EQ(Bonus, Ref);
287   EXPECT_TRUE(Bonus > 0);
288 
289   // select
290   Ref = getInstCost(Select);
291   Bonus = Specializer.getSpecializationBonus(F->getArg(1), True, Visitor);
292   EXPECT_EQ(Bonus, Ref);
293   EXPECT_TRUE(Bonus > 0);
294 
295   // gep + load + freeze + smax
296   Ref = getInstCost(Gep) + getInstCost(Load) + getInstCost(Freeze) +
297         getInstCost(Smax);
298   Bonus = Specializer.getSpecializationBonus(F->getArg(2), GV, Visitor);
299   EXPECT_EQ(Bonus, Ref);
300   EXPECT_TRUE(Bonus > 0);
301 
302   Bonus = Specializer.getSpecializationBonus(F->getArg(3), Undef, Visitor);
303   EXPECT_TRUE(Bonus == 0);
304 }
305