xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 5f30b1aae0a3e2d3c4c9a50ef4af9457fbea094f)
1 //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/AssumptionCache.h"
10 #include "llvm/Analysis/BlockFrequencyInfo.h"
11 #include "llvm/Analysis/BranchProbabilityInfo.h"
12 #include "llvm/Analysis/LoopInfo.h"
13 #include "llvm/Analysis/PostDominators.h"
14 #include "llvm/Analysis/TargetLibraryInfo.h"
15 #include "llvm/Analysis/TargetTransformInfo.h"
16 #include "llvm/AsmParser/Parser.h"
17 #include "llvm/IR/Constants.h"
18 #include "llvm/IR/PassInstrumentation.h"
19 #include "llvm/Support/SourceMgr.h"
20 #include "llvm/Transforms/IPO/FunctionSpecialization.h"
21 #include "llvm/Transforms/Utils/SCCPSolver.h"
22 #include "gtest/gtest.h"
23 #include <memory>
24 
25 namespace llvm {
26 
27 static void removeSSACopy(Function &F) {
28   for (BasicBlock &BB : F) {
29     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
30       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
31         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
32           continue;
33         Inst.replaceAllUsesWith(II->getOperand(0));
34         Inst.eraseFromParent();
35       }
36     }
37   }
38 }
39 
40 class FunctionSpecializationTest : public testing::Test {
41 protected:
42   LLVMContext Ctx;
43   FunctionAnalysisManager FAM;
44   std::unique_ptr<Module> M;
45   std::unique_ptr<SCCPSolver> Solver;
46   SmallVector<Instruction *, 8> KnownConstants;
47 
48   FunctionSpecializationTest() {
49     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
50     FAM.registerPass([&] { return TargetIRAnalysis(); });
51     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
52     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
53     FAM.registerPass([&] { return LoopAnalysis(); });
54     FAM.registerPass([&] { return AssumptionAnalysis(); });
55     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
56     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
57     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
58   }
59 
60   Module &parseModule(const char *ModuleString) {
61     SMDiagnostic Err;
62     M = parseAssemblyString(ModuleString, Err, Ctx);
63     EXPECT_TRUE(M);
64     return *M;
65   }
66 
67   FunctionSpecializer getSpecializerFor(Function *F) {
68     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
69       return FAM.getResult<TargetLibraryAnalysis>(F);
70     };
71     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
72       return FAM.getResult<TargetIRAnalysis>(F);
73     };
74     auto GetAC = [this](Function &F) -> AssumptionCache & {
75       return FAM.getResult<AssumptionAnalysis>(F);
76     };
77     auto GetDT = [this](Function &F) -> DominatorTree & {
78       return FAM.getResult<DominatorTreeAnalysis>(F);
79     };
80     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
81       return FAM.getResult<BlockFrequencyAnalysis>(F);
82     };
83 
84     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
85 
86     DominatorTree &DT = GetDT(*F);
87     AssumptionCache &AC = GetAC(*F);
88     Solver->addPredicateInfo(*F, DT, AC);
89 
90     Solver->markBlockExecutable(&F->front());
91     for (Argument &Arg : F->args())
92       Solver->markOverdefined(&Arg);
93     Solver->solveWhileResolvedUndefsIn(*M);
94 
95     removeSSACopy(*F);
96 
97     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
98                                GetAC);
99   }
100 
101   Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) {
102     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
103 
104     Cost CodeSize =
105         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
106 
107     if (HasLatencySavings)
108       KnownConstants.push_back(&I);
109 
110     return CodeSize;
111   }
112 
113   Cost getLatencySavings(Function *F) {
114     auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
115     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F);
116 
117     Cost Latency = 0;
118     for (const Instruction *I : KnownConstants)
119       Latency += BFI.getBlockFreq(I->getParent()).getFrequency() /
120                  BFI.getEntryFreq().getFrequency() *
121                  TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
122 
123     return Latency;
124   }
125 };
126 
127 } // namespace llvm
128 
129 using namespace llvm;
130 
131 TEST_F(FunctionSpecializationTest, SwitchInst) {
132   const char *ModuleString = R"(
133     define void @foo(i32 %a, i32 %b, i32 %i) {
134     entry:
135       br label %loop
136     loop:
137       switch i32 %i, label %default
138       [ i32 1, label %case1
139         i32 2, label %case2 ]
140     case1:
141       %0 = mul i32 %a, 2
142       %1 = sub i32 6, 5
143       br label %bb1
144     case2:
145       %2 = and i32 %b, 3
146       %3 = sdiv i32 8, 2
147       br label %bb2
148     bb1:
149       %4 = add i32 %0, %b
150       br label %loop
151     bb2:
152       %5 = or i32 %2, %a
153       br label %loop
154     default:
155       ret void
156     }
157   )";
158 
159   Module &M = parseModule(ModuleString);
160   Function *F = M.getFunction("foo");
161   FunctionSpecializer Specializer = getSpecializerFor(F);
162   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
163 
164   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
165 
166   auto FuncIter = F->begin();
167   BasicBlock &Loop = *++FuncIter;
168   BasicBlock &Case1 = *++FuncIter;
169   BasicBlock &Case2 = *++FuncIter;
170   BasicBlock &BB1 = *++FuncIter;
171   BasicBlock &BB2 = *++FuncIter;
172 
173   Instruction &Switch = Loop.front();
174   Instruction &Mul = Case1.front();
175   Instruction &And = Case2.front();
176   Instruction &Sdiv = *++Case2.begin();
177   Instruction &BrBB2 = Case2.back();
178   Instruction &Add = BB1.front();
179   Instruction &Or = BB2.front();
180   Instruction &BrLoop = BB2.back();
181 
182   // mul
183   Cost Ref = getCodeSizeSavings(Mul);
184   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
185   EXPECT_EQ(Test, Ref);
186   EXPECT_TRUE(Test > 0);
187 
188   // and + or + add
189   Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) +
190         getCodeSizeSavings(Add);
191   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
192   EXPECT_EQ(Test, Ref);
193   EXPECT_TRUE(Test > 0);
194 
195   // switch + sdiv + br + br
196   Ref = getCodeSizeSavings(Switch) +
197         getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
198         getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
199         getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
200   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
201   EXPECT_EQ(Test, Ref);
202   EXPECT_TRUE(Test > 0);
203 
204   // Latency.
205   Ref = getLatencySavings(F);
206   Test = Visitor.getLatencySavingsForKnownConstants();
207   EXPECT_EQ(Test, Ref);
208   EXPECT_TRUE(Test > 0);
209 }
210 
211 TEST_F(FunctionSpecializationTest, BranchInst) {
212   const char *ModuleString = R"(
213     define void @foo(i32 %a, i32 %b, i1 %cond) {
214     entry:
215       br label %loop
216     loop:
217       br i1 %cond, label %bb0, label %bb3
218     bb0:
219       %0 = mul i32 %a, 2
220       %1 = sub i32 6, 5
221       br i1 %cond, label %bb1, label %bb2
222     bb1:
223       %2 = add i32 %0, %b
224       %3 = sdiv i32 8, 2
225       br label %bb2
226     bb2:
227       br label %loop
228     bb3:
229       ret void
230     }
231   )";
232 
233   Module &M = parseModule(ModuleString);
234   Function *F = M.getFunction("foo");
235   FunctionSpecializer Specializer = getSpecializerFor(F);
236   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
237 
238   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
239   Constant *False = ConstantInt::getFalse(M.getContext());
240 
241   auto FuncIter = F->begin();
242   BasicBlock &Loop = *++FuncIter;
243   BasicBlock &BB0 = *++FuncIter;
244   BasicBlock &BB1 = *++FuncIter;
245   BasicBlock &BB2 = *++FuncIter;
246 
247   Instruction &Branch = Loop.front();
248   Instruction &Mul = BB0.front();
249   Instruction &Sub = *++BB0.begin();
250   Instruction &BrBB1BB2 = BB0.back();
251   Instruction &Add = BB1.front();
252   Instruction &Sdiv = *++BB1.begin();
253   Instruction &BrBB2 = BB1.back();
254   Instruction &BrLoop = BB2.front();
255 
256   // mul
257   Cost Ref = getCodeSizeSavings(Mul);
258   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
259   EXPECT_EQ(Test, Ref);
260   EXPECT_TRUE(Test > 0);
261 
262   // add
263   Ref = getCodeSizeSavings(Add);
264   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
265   EXPECT_EQ(Test, Ref);
266   EXPECT_TRUE(Test > 0);
267 
268   // branch + sub + br + sdiv + br
269   Ref = getCodeSizeSavings(Branch) +
270         getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) +
271         getCodeSizeSavings(BrBB1BB2) +
272         getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
273         getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
274         getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
275   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False);
276   EXPECT_EQ(Test, Ref);
277   EXPECT_TRUE(Test > 0);
278 
279   // Latency.
280   Ref = getLatencySavings(F);
281   Test = Visitor.getLatencySavingsForKnownConstants();
282   EXPECT_EQ(Test, Ref);
283   EXPECT_TRUE(Test > 0);
284 }
285 
286 TEST_F(FunctionSpecializationTest, SelectInst) {
287   const char *ModuleString = R"(
288     define i32 @foo(i1 %cond, i32 %a, i32 %b) {
289       %sel = select i1 %cond, i32 %a, i32 %b
290       ret i32 %sel
291     }
292   )";
293 
294   Module &M = parseModule(ModuleString);
295   Function *F = M.getFunction("foo");
296   FunctionSpecializer Specializer = getSpecializerFor(F);
297   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
298 
299   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
300   Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
301   Constant *False = ConstantInt::getFalse(M.getContext());
302   Instruction &Select = *F->front().begin();
303 
304   Cost RefCodeSize = getCodeSizeSavings(Select);
305   Cost RefLatency = getLatencySavings(F);
306 
307   Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
308   EXPECT_TRUE(TestCodeSize == 0);
309   TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
310   EXPECT_TRUE(TestCodeSize == 0);
311   Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
312   EXPECT_TRUE(TestLatency == 0);
313 
314   TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero);
315   EXPECT_EQ(TestCodeSize, RefCodeSize);
316   EXPECT_TRUE(TestCodeSize > 0);
317   TestLatency = Visitor.getLatencySavingsForKnownConstants();
318   EXPECT_EQ(TestLatency, RefLatency);
319   EXPECT_TRUE(TestLatency > 0);
320 }
321 
322 TEST_F(FunctionSpecializationTest, Misc) {
323   const char *ModuleString = R"(
324     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
325     @g = constant %struct_t zeroinitializer, align 16
326 
327     declare i32 @llvm.smax.i32(i32, i32)
328     declare i32 @bar(i32)
329 
330     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
331       %cmp = icmp eq i8 %a, 10
332       %ext = zext i1 %cmp to i64
333       %sel = select i1 %cond, i64 %ext, i64 1
334       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
335       %ld = load i32, ptr %gep
336       %fr = freeze i32 %ld
337       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
338       %call = call i32 @bar(i32 %smax)
339       %fr2 = freeze i32 %c
340       %add = add i32 %call, %fr2
341       ret i32 %add
342     }
343   )";
344 
345   Module &M = parseModule(ModuleString);
346   Function *F = M.getFunction("foo");
347   FunctionSpecializer Specializer = getSpecializerFor(F);
348   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
349 
350   GlobalVariable *GV = M.getGlobalVariable("g");
351   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
352   Constant *True = ConstantInt::getTrue(M.getContext());
353   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
354 
355   auto BlockIter = F->front().begin();
356   Instruction &Icmp = *BlockIter++;
357   Instruction &Zext = *BlockIter++;
358   Instruction &Select = *BlockIter++;
359   Instruction &Gep = *BlockIter++;
360   Instruction &Load = *BlockIter++;
361   Instruction &Freeze = *BlockIter++;
362   Instruction &Smax = *BlockIter++;
363 
364   // icmp + zext
365   Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext);
366   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
367   EXPECT_EQ(Test, Ref);
368   EXPECT_TRUE(Test > 0);
369 
370   // select
371   Ref = getCodeSizeSavings(Select);
372   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True);
373   EXPECT_EQ(Test, Ref);
374   EXPECT_TRUE(Test > 0);
375 
376   // gep + load + freeze + smax
377   Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) +
378         getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax);
379   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV);
380   EXPECT_EQ(Test, Ref);
381   EXPECT_TRUE(Test > 0);
382 
383   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef);
384   EXPECT_TRUE(Test == 0);
385 
386   // Latency.
387   Ref = getLatencySavings(F);
388   Test = Visitor.getLatencySavingsForKnownConstants();
389   EXPECT_EQ(Test, Ref);
390   EXPECT_TRUE(Test > 0);
391 }
392 
393 TEST_F(FunctionSpecializationTest, PhiNode) {
394   const char *ModuleString = R"(
395     define void @foo(i32 %a, i32 %b, i32 %i) {
396     entry:
397       br label %loop
398     loop:
399       %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
400       switch i32 %i, label %default
401       [ i32 1, label %case1
402         i32 2, label %case2 ]
403     case1:
404       %1 = add i32 %0, 1
405       br label %bb
406     case2:
407       %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
408       br label %bb
409     bb:
410       %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
411       %4 = icmp eq i32 %3, 1
412       br i1 %4, label %bb, label %loop
413     default:
414       ret void
415     }
416   )";
417 
418   Module &M = parseModule(ModuleString);
419   Function *F = M.getFunction("foo");
420   FunctionSpecializer Specializer = getSpecializerFor(F);
421   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
422 
423   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
424 
425   auto FuncIter = F->begin();
426   BasicBlock &Loop = *++FuncIter;
427   BasicBlock &Case1 = *++FuncIter;
428   BasicBlock &Case2 = *++FuncIter;
429   BasicBlock &BB = *++FuncIter;
430 
431   Instruction &PhiLoop = Loop.front();
432   Instruction &Switch = Loop.back();
433   Instruction &Add = Case1.front();
434   Instruction &PhiCase2 = Case2.front();
435   Instruction &BrBB = Case2.back();
436   Instruction &PhiBB = BB.front();
437   Instruction &Icmp = *++BB.begin();
438   Instruction &Branch = BB.back();
439 
440   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
441   EXPECT_TRUE(Test == 0);
442 
443   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
444   EXPECT_TRUE(Test == 0);
445 
446   Test = Visitor.getLatencySavingsForKnownConstants();
447   EXPECT_TRUE(Test == 0);
448 
449   // switch + phi + br
450   Cost Ref = getCodeSizeSavings(Switch) +
451              getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) +
452              getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false);
453   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
454   EXPECT_EQ(Test, Ref);
455   EXPECT_TRUE(Test > 0 && Test > 0);
456 
457   // phi + phi + add + icmp + branch
458   Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) +
459         getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) +
460         getCodeSizeSavings(Branch);
461   Test = Visitor.getCodeSizeSavingsFromPendingPHIs();
462   EXPECT_EQ(Test, Ref);
463   EXPECT_TRUE(Test > 0);
464 
465   // Latency.
466   Ref = getLatencySavings(F);
467   Test = Visitor.getLatencySavingsForKnownConstants();
468   EXPECT_EQ(Test, Ref);
469   EXPECT_TRUE(Test > 0);
470 }
471 
472 TEST_F(FunctionSpecializationTest, BinOp) {
473   // Verify that we can handle binary operators even when only one operand is
474   // constant.
475   const char *ModuleString = R"(
476     define i32 @foo(i1 %a, i1 %b) {
477       %and1 = and i1 %a, %b
478       %and2 = and i1 %b, %and1
479       %sel = select i1 %and2, i32 1, i32 0
480       ret i32 %sel
481     }
482   )";
483 
484   Module &M = parseModule(ModuleString);
485   Function *F = M.getFunction("foo");
486   FunctionSpecializer Specializer = getSpecializerFor(F);
487   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
488 
489   Constant *False = ConstantInt::getFalse(M.getContext());
490   BasicBlock &BB = F->front();
491   Instruction &And1 = BB.front();
492   Instruction &And2 = *++BB.begin();
493   Instruction &Select = *++BB.begin();
494 
495   Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) +
496                      getCodeSizeSavings(Select);
497   Cost RefLatency = getLatencySavings(F);
498 
499   Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
500   Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
501 
502   EXPECT_EQ(TestCodeSize, RefCodeSize);
503   EXPECT_TRUE(TestCodeSize > 0);
504   EXPECT_EQ(TestLatency, RefLatency);
505   EXPECT_TRUE(TestLatency > 0);
506 }
507