xref: /llvm-project/llvm/unittests/Transforms/IPO/FunctionSpecializationTest.cpp (revision 5f30b1aae0a3e2d3c4c9a50ef4af9457fbea094f)
14d13896dSAlexandros Lamprineas //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===//
24d13896dSAlexandros Lamprineas //
34d13896dSAlexandros Lamprineas // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44d13896dSAlexandros Lamprineas // See https://llvm.org/LICENSE.txt for license information.
54d13896dSAlexandros Lamprineas // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
64d13896dSAlexandros Lamprineas //
74d13896dSAlexandros Lamprineas //===----------------------------------------------------------------------===//
84d13896dSAlexandros Lamprineas 
94d13896dSAlexandros Lamprineas #include "llvm/Analysis/AssumptionCache.h"
104d13896dSAlexandros Lamprineas #include "llvm/Analysis/BlockFrequencyInfo.h"
114d13896dSAlexandros Lamprineas #include "llvm/Analysis/BranchProbabilityInfo.h"
124d13896dSAlexandros Lamprineas #include "llvm/Analysis/LoopInfo.h"
134d13896dSAlexandros Lamprineas #include "llvm/Analysis/PostDominators.h"
144d13896dSAlexandros Lamprineas #include "llvm/Analysis/TargetLibraryInfo.h"
154d13896dSAlexandros Lamprineas #include "llvm/Analysis/TargetTransformInfo.h"
164d13896dSAlexandros Lamprineas #include "llvm/AsmParser/Parser.h"
174d13896dSAlexandros Lamprineas #include "llvm/IR/Constants.h"
1836c6632eSNikita Popov #include "llvm/IR/PassInstrumentation.h"
194d13896dSAlexandros Lamprineas #include "llvm/Support/SourceMgr.h"
204d13896dSAlexandros Lamprineas #include "llvm/Transforms/IPO/FunctionSpecialization.h"
214d13896dSAlexandros Lamprineas #include "llvm/Transforms/Utils/SCCPSolver.h"
224d13896dSAlexandros Lamprineas #include "gtest/gtest.h"
234d13896dSAlexandros Lamprineas #include <memory>
244d13896dSAlexandros Lamprineas 
254d13896dSAlexandros Lamprineas namespace llvm {
264d13896dSAlexandros Lamprineas 
272e00eba2SAlexandros Lamprineas static void removeSSACopy(Function &F) {
282e00eba2SAlexandros Lamprineas   for (BasicBlock &BB : F) {
292e00eba2SAlexandros Lamprineas     for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
302e00eba2SAlexandros Lamprineas       if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
312e00eba2SAlexandros Lamprineas         if (II->getIntrinsicID() != Intrinsic::ssa_copy)
322e00eba2SAlexandros Lamprineas           continue;
332e00eba2SAlexandros Lamprineas         Inst.replaceAllUsesWith(II->getOperand(0));
342e00eba2SAlexandros Lamprineas         Inst.eraseFromParent();
352e00eba2SAlexandros Lamprineas       }
362e00eba2SAlexandros Lamprineas     }
372e00eba2SAlexandros Lamprineas   }
382e00eba2SAlexandros Lamprineas }
392e00eba2SAlexandros Lamprineas 
404d13896dSAlexandros Lamprineas class FunctionSpecializationTest : public testing::Test {
414d13896dSAlexandros Lamprineas protected:
424d13896dSAlexandros Lamprineas   LLVMContext Ctx;
434d13896dSAlexandros Lamprineas   FunctionAnalysisManager FAM;
444d13896dSAlexandros Lamprineas   std::unique_ptr<Module> M;
454d13896dSAlexandros Lamprineas   std::unique_ptr<SCCPSolver> Solver;
46c6931c25SHari Limaye   SmallVector<Instruction *, 8> KnownConstants;
474d13896dSAlexandros Lamprineas 
484d13896dSAlexandros Lamprineas   FunctionSpecializationTest() {
494d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return TargetLibraryAnalysis(); });
504d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return TargetIRAnalysis(); });
514d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return BlockFrequencyAnalysis(); });
524d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return BranchProbabilityAnalysis(); });
534d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return LoopAnalysis(); });
544d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return AssumptionAnalysis(); });
554d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return DominatorTreeAnalysis(); });
564d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return PostDominatorTreeAnalysis(); });
574d13896dSAlexandros Lamprineas     FAM.registerPass([&] { return PassInstrumentationAnalysis(); });
584d13896dSAlexandros Lamprineas   }
594d13896dSAlexandros Lamprineas 
604d13896dSAlexandros Lamprineas   Module &parseModule(const char *ModuleString) {
614d13896dSAlexandros Lamprineas     SMDiagnostic Err;
624d13896dSAlexandros Lamprineas     M = parseAssemblyString(ModuleString, Err, Ctx);
634d13896dSAlexandros Lamprineas     EXPECT_TRUE(M);
644d13896dSAlexandros Lamprineas     return *M;
654d13896dSAlexandros Lamprineas   }
664d13896dSAlexandros Lamprineas 
674d13896dSAlexandros Lamprineas   FunctionSpecializer getSpecializerFor(Function *F) {
684d13896dSAlexandros Lamprineas     auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & {
694d13896dSAlexandros Lamprineas       return FAM.getResult<TargetLibraryAnalysis>(F);
704d13896dSAlexandros Lamprineas     };
714d13896dSAlexandros Lamprineas     auto GetTTI = [this](Function &F) -> TargetTransformInfo & {
724d13896dSAlexandros Lamprineas       return FAM.getResult<TargetIRAnalysis>(F);
734d13896dSAlexandros Lamprineas     };
744d13896dSAlexandros Lamprineas     auto GetAC = [this](Function &F) -> AssumptionCache & {
754d13896dSAlexandros Lamprineas       return FAM.getResult<AssumptionAnalysis>(F);
764d13896dSAlexandros Lamprineas     };
774d13896dSAlexandros Lamprineas     auto GetDT = [this](Function &F) -> DominatorTree & {
784d13896dSAlexandros Lamprineas       return FAM.getResult<DominatorTreeAnalysis>(F);
794d13896dSAlexandros Lamprineas     };
804d13896dSAlexandros Lamprineas     auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & {
814d13896dSAlexandros Lamprineas       return FAM.getResult<BlockFrequencyAnalysis>(F);
824d13896dSAlexandros Lamprineas     };
834d13896dSAlexandros Lamprineas 
844d13896dSAlexandros Lamprineas     Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx);
854d13896dSAlexandros Lamprineas 
864d13896dSAlexandros Lamprineas     DominatorTree &DT = GetDT(*F);
874d13896dSAlexandros Lamprineas     AssumptionCache &AC = GetAC(*F);
884d13896dSAlexandros Lamprineas     Solver->addPredicateInfo(*F, DT, AC);
894d13896dSAlexandros Lamprineas 
904d13896dSAlexandros Lamprineas     Solver->markBlockExecutable(&F->front());
914d13896dSAlexandros Lamprineas     for (Argument &Arg : F->args())
924d13896dSAlexandros Lamprineas       Solver->markOverdefined(&Arg);
934d13896dSAlexandros Lamprineas     Solver->solveWhileResolvedUndefsIn(*M);
944d13896dSAlexandros Lamprineas 
952e00eba2SAlexandros Lamprineas     removeSSACopy(*F);
962e00eba2SAlexandros Lamprineas 
974d13896dSAlexandros Lamprineas     return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI,
984d13896dSAlexandros Lamprineas                                GetAC);
994d13896dSAlexandros Lamprineas   }
1004d13896dSAlexandros Lamprineas 
101c6931c25SHari Limaye   Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) {
1024d13896dSAlexandros Lamprineas     auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction());
1034d13896dSAlexandros Lamprineas 
1045bfefff1SAlexandros Lamprineas     Cost CodeSize =
1055bfefff1SAlexandros Lamprineas         TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize);
1065bfefff1SAlexandros Lamprineas 
107c6931c25SHari Limaye     if (HasLatencySavings)
108c6931c25SHari Limaye       KnownConstants.push_back(&I);
1095bfefff1SAlexandros Lamprineas 
110c6931c25SHari Limaye     return CodeSize;
111c6931c25SHari Limaye   }
112c6931c25SHari Limaye 
113c6931c25SHari Limaye   Cost getLatencySavings(Function *F) {
114c6931c25SHari Limaye     auto &TTI = FAM.getResult<TargetIRAnalysis>(*F);
115c6931c25SHari Limaye     auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F);
116c6931c25SHari Limaye 
117c6931c25SHari Limaye     Cost Latency = 0;
118c6931c25SHari Limaye     for (const Instruction *I : KnownConstants)
119c6931c25SHari Limaye       Latency += BFI.getBlockFreq(I->getParent()).getFrequency() /
120c6931c25SHari Limaye                  BFI.getEntryFreq().getFrequency() *
121c6931c25SHari Limaye                  TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency);
122c6931c25SHari Limaye 
123c6931c25SHari Limaye     return Latency;
1244d13896dSAlexandros Lamprineas   }
1254d13896dSAlexandros Lamprineas };
1264d13896dSAlexandros Lamprineas 
1274d13896dSAlexandros Lamprineas } // namespace llvm
1284d13896dSAlexandros Lamprineas 
1294d13896dSAlexandros Lamprineas using namespace llvm;
1304d13896dSAlexandros Lamprineas 
1314d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, SwitchInst) {
1324d13896dSAlexandros Lamprineas   const char *ModuleString = R"(
1334d13896dSAlexandros Lamprineas     define void @foo(i32 %a, i32 %b, i32 %i) {
1344d13896dSAlexandros Lamprineas     entry:
135f11d8c88SAlexandros Lamprineas       br label %loop
136f11d8c88SAlexandros Lamprineas     loop:
1374d13896dSAlexandros Lamprineas       switch i32 %i, label %default
1384d13896dSAlexandros Lamprineas       [ i32 1, label %case1
1394d13896dSAlexandros Lamprineas         i32 2, label %case2 ]
1404d13896dSAlexandros Lamprineas     case1:
1414d13896dSAlexandros Lamprineas       %0 = mul i32 %a, 2
1424d13896dSAlexandros Lamprineas       %1 = sub i32 6, 5
1434d13896dSAlexandros Lamprineas       br label %bb1
1444d13896dSAlexandros Lamprineas     case2:
1454d13896dSAlexandros Lamprineas       %2 = and i32 %b, 3
1464d13896dSAlexandros Lamprineas       %3 = sdiv i32 8, 2
1474d13896dSAlexandros Lamprineas       br label %bb2
1484d13896dSAlexandros Lamprineas     bb1:
1494d13896dSAlexandros Lamprineas       %4 = add i32 %0, %b
150f11d8c88SAlexandros Lamprineas       br label %loop
1514d13896dSAlexandros Lamprineas     bb2:
1524d13896dSAlexandros Lamprineas       %5 = or i32 %2, %a
153f11d8c88SAlexandros Lamprineas       br label %loop
1544d13896dSAlexandros Lamprineas     default:
1554d13896dSAlexandros Lamprineas       ret void
1564d13896dSAlexandros Lamprineas     }
1574d13896dSAlexandros Lamprineas   )";
1584d13896dSAlexandros Lamprineas 
1594d13896dSAlexandros Lamprineas   Module &M = parseModule(ModuleString);
1604d13896dSAlexandros Lamprineas   Function *F = M.getFunction("foo");
1614d13896dSAlexandros Lamprineas   FunctionSpecializer Specializer = getSpecializerFor(F);
1624d13896dSAlexandros Lamprineas   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
1634d13896dSAlexandros Lamprineas 
1644d13896dSAlexandros Lamprineas   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
1654d13896dSAlexandros Lamprineas 
1664d13896dSAlexandros Lamprineas   auto FuncIter = F->begin();
1675bfefff1SAlexandros Lamprineas   BasicBlock &Loop = *++FuncIter;
1684d13896dSAlexandros Lamprineas   BasicBlock &Case1 = *++FuncIter;
1694d13896dSAlexandros Lamprineas   BasicBlock &Case2 = *++FuncIter;
1704d13896dSAlexandros Lamprineas   BasicBlock &BB1 = *++FuncIter;
1714d13896dSAlexandros Lamprineas   BasicBlock &BB2 = *++FuncIter;
1724d13896dSAlexandros Lamprineas 
1735bfefff1SAlexandros Lamprineas   Instruction &Switch = Loop.front();
1744d13896dSAlexandros Lamprineas   Instruction &Mul = Case1.front();
1754d13896dSAlexandros Lamprineas   Instruction &And = Case2.front();
1764d13896dSAlexandros Lamprineas   Instruction &Sdiv = *++Case2.begin();
1774d13896dSAlexandros Lamprineas   Instruction &BrBB2 = Case2.back();
1784d13896dSAlexandros Lamprineas   Instruction &Add = BB1.front();
1794d13896dSAlexandros Lamprineas   Instruction &Or = BB2.front();
180f11d8c88SAlexandros Lamprineas   Instruction &BrLoop = BB2.back();
1814d13896dSAlexandros Lamprineas 
1824d13896dSAlexandros Lamprineas   // mul
183c6931c25SHari Limaye   Cost Ref = getCodeSizeSavings(Mul);
184c6931c25SHari Limaye   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
1855bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
186c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
1874d13896dSAlexandros Lamprineas 
1884d13896dSAlexandros Lamprineas   // and + or + add
189c6931c25SHari Limaye   Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) +
190c6931c25SHari Limaye         getCodeSizeSavings(Add);
191c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
1925bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
193c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
1944d13896dSAlexandros Lamprineas 
1955bfefff1SAlexandros Lamprineas   // switch + sdiv + br + br
196c6931c25SHari Limaye   Ref = getCodeSizeSavings(Switch) +
197c6931c25SHari Limaye         getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
198c6931c25SHari Limaye         getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
199c6931c25SHari Limaye         getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
200c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
2015bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
202c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
203c6931c25SHari Limaye 
204c6931c25SHari Limaye   // Latency.
205c6931c25SHari Limaye   Ref = getLatencySavings(F);
206c6931c25SHari Limaye   Test = Visitor.getLatencySavingsForKnownConstants();
207c6931c25SHari Limaye   EXPECT_EQ(Test, Ref);
208c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
2094d13896dSAlexandros Lamprineas }
2104d13896dSAlexandros Lamprineas 
2114d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, BranchInst) {
2124d13896dSAlexandros Lamprineas   const char *ModuleString = R"(
2134d13896dSAlexandros Lamprineas     define void @foo(i32 %a, i32 %b, i1 %cond) {
2144d13896dSAlexandros Lamprineas     entry:
215f11d8c88SAlexandros Lamprineas       br label %loop
216f11d8c88SAlexandros Lamprineas     loop:
217c2d19002SAlexandros Lamprineas       br i1 %cond, label %bb0, label %bb3
2184d13896dSAlexandros Lamprineas     bb0:
2194d13896dSAlexandros Lamprineas       %0 = mul i32 %a, 2
2204d13896dSAlexandros Lamprineas       %1 = sub i32 6, 5
221c2d19002SAlexandros Lamprineas       br i1 %cond, label %bb1, label %bb2
2224d13896dSAlexandros Lamprineas     bb1:
2234d13896dSAlexandros Lamprineas       %2 = add i32 %0, %b
2244d13896dSAlexandros Lamprineas       %3 = sdiv i32 8, 2
225c2d19002SAlexandros Lamprineas       br label %bb2
2264d13896dSAlexandros Lamprineas     bb2:
227c2d19002SAlexandros Lamprineas       br label %loop
228c2d19002SAlexandros Lamprineas     bb3:
2294d13896dSAlexandros Lamprineas       ret void
2304d13896dSAlexandros Lamprineas     }
2314d13896dSAlexandros Lamprineas   )";
2324d13896dSAlexandros Lamprineas 
2334d13896dSAlexandros Lamprineas   Module &M = parseModule(ModuleString);
2344d13896dSAlexandros Lamprineas   Function *F = M.getFunction("foo");
2354d13896dSAlexandros Lamprineas   FunctionSpecializer Specializer = getSpecializerFor(F);
2364d13896dSAlexandros Lamprineas   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
2374d13896dSAlexandros Lamprineas 
2384d13896dSAlexandros Lamprineas   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
2394d13896dSAlexandros Lamprineas   Constant *False = ConstantInt::getFalse(M.getContext());
2404d13896dSAlexandros Lamprineas 
2414d13896dSAlexandros Lamprineas   auto FuncIter = F->begin();
2425bfefff1SAlexandros Lamprineas   BasicBlock &Loop = *++FuncIter;
2434d13896dSAlexandros Lamprineas   BasicBlock &BB0 = *++FuncIter;
2444d13896dSAlexandros Lamprineas   BasicBlock &BB1 = *++FuncIter;
245c2d19002SAlexandros Lamprineas   BasicBlock &BB2 = *++FuncIter;
2464d13896dSAlexandros Lamprineas 
2475bfefff1SAlexandros Lamprineas   Instruction &Branch = Loop.front();
2484d13896dSAlexandros Lamprineas   Instruction &Mul = BB0.front();
2494d13896dSAlexandros Lamprineas   Instruction &Sub = *++BB0.begin();
250c2d19002SAlexandros Lamprineas   Instruction &BrBB1BB2 = BB0.back();
2514d13896dSAlexandros Lamprineas   Instruction &Add = BB1.front();
2524d13896dSAlexandros Lamprineas   Instruction &Sdiv = *++BB1.begin();
253c2d19002SAlexandros Lamprineas   Instruction &BrBB2 = BB1.back();
254c2d19002SAlexandros Lamprineas   Instruction &BrLoop = BB2.front();
2554d13896dSAlexandros Lamprineas 
2564d13896dSAlexandros Lamprineas   // mul
257c6931c25SHari Limaye   Cost Ref = getCodeSizeSavings(Mul);
258c6931c25SHari Limaye   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
2595bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
260c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
2614d13896dSAlexandros Lamprineas 
2624d13896dSAlexandros Lamprineas   // add
263c6931c25SHari Limaye   Ref = getCodeSizeSavings(Add);
264c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
2655bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
266c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
2674d13896dSAlexandros Lamprineas 
2685bfefff1SAlexandros Lamprineas   // branch + sub + br + sdiv + br
269c6931c25SHari Limaye   Ref = getCodeSizeSavings(Branch) +
270c6931c25SHari Limaye         getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) +
271c6931c25SHari Limaye         getCodeSizeSavings(BrBB1BB2) +
272c6931c25SHari Limaye         getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) +
273c6931c25SHari Limaye         getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) +
274c6931c25SHari Limaye         getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false);
275c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False);
2765bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
277c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
278c6931c25SHari Limaye 
279c6931c25SHari Limaye   // Latency.
280c6931c25SHari Limaye   Ref = getLatencySavings(F);
281c6931c25SHari Limaye   Test = Visitor.getLatencySavingsForKnownConstants();
282c6931c25SHari Limaye   EXPECT_EQ(Test, Ref);
283c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
2844d13896dSAlexandros Lamprineas }
2854d13896dSAlexandros Lamprineas 
2866472cb1eSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, SelectInst) {
2876472cb1eSAlexandros Lamprineas   const char *ModuleString = R"(
2886472cb1eSAlexandros Lamprineas     define i32 @foo(i1 %cond, i32 %a, i32 %b) {
2896472cb1eSAlexandros Lamprineas       %sel = select i1 %cond, i32 %a, i32 %b
2906472cb1eSAlexandros Lamprineas       ret i32 %sel
2916472cb1eSAlexandros Lamprineas     }
2926472cb1eSAlexandros Lamprineas   )";
2936472cb1eSAlexandros Lamprineas 
2946472cb1eSAlexandros Lamprineas   Module &M = parseModule(ModuleString);
2956472cb1eSAlexandros Lamprineas   Function *F = M.getFunction("foo");
2966472cb1eSAlexandros Lamprineas   FunctionSpecializer Specializer = getSpecializerFor(F);
2976472cb1eSAlexandros Lamprineas   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
2986472cb1eSAlexandros Lamprineas 
2996472cb1eSAlexandros Lamprineas   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
3006472cb1eSAlexandros Lamprineas   Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0);
3016472cb1eSAlexandros Lamprineas   Constant *False = ConstantInt::getFalse(M.getContext());
3026472cb1eSAlexandros Lamprineas   Instruction &Select = *F->front().begin();
3036472cb1eSAlexandros Lamprineas 
304c6931c25SHari Limaye   Cost RefCodeSize = getCodeSizeSavings(Select);
305c6931c25SHari Limaye   Cost RefLatency = getLatencySavings(F);
306c6931c25SHari Limaye 
307c6931c25SHari Limaye   Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
308c6931c25SHari Limaye   EXPECT_TRUE(TestCodeSize == 0);
309c6931c25SHari Limaye   TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
310c6931c25SHari Limaye   EXPECT_TRUE(TestCodeSize == 0);
311c6931c25SHari Limaye   Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
312c6931c25SHari Limaye   EXPECT_TRUE(TestLatency == 0);
313c6931c25SHari Limaye 
314c6931c25SHari Limaye   TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero);
315c6931c25SHari Limaye   EXPECT_EQ(TestCodeSize, RefCodeSize);
316c6931c25SHari Limaye   EXPECT_TRUE(TestCodeSize > 0);
317c6931c25SHari Limaye   TestLatency = Visitor.getLatencySavingsForKnownConstants();
318c6931c25SHari Limaye   EXPECT_EQ(TestLatency, RefLatency);
319c6931c25SHari Limaye   EXPECT_TRUE(TestLatency > 0);
3206472cb1eSAlexandros Lamprineas }
3216472cb1eSAlexandros Lamprineas 
3224d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, Misc) {
3234d13896dSAlexandros Lamprineas   const char *ModuleString = R"(
324cae00b2aSAlexandros Lamprineas     %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] }
325cae00b2aSAlexandros Lamprineas     @g = constant %struct_t zeroinitializer, align 16
3264d13896dSAlexandros Lamprineas 
3275400257dSAlexandros Lamprineas     declare i32 @llvm.smax.i32(i32, i32)
3285400257dSAlexandros Lamprineas     declare i32 @bar(i32)
3295400257dSAlexandros Lamprineas 
3305400257dSAlexandros Lamprineas     define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) {
3314d13896dSAlexandros Lamprineas       %cmp = icmp eq i8 %a, 10
332cae00b2aSAlexandros Lamprineas       %ext = zext i1 %cmp to i64
333cae00b2aSAlexandros Lamprineas       %sel = select i1 %cond, i64 %ext, i64 1
334cae00b2aSAlexandros Lamprineas       %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4
3354d13896dSAlexandros Lamprineas       %ld = load i32, ptr %gep
3365400257dSAlexandros Lamprineas       %fr = freeze i32 %ld
3375400257dSAlexandros Lamprineas       %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1)
3385400257dSAlexandros Lamprineas       %call = call i32 @bar(i32 %smax)
3395400257dSAlexandros Lamprineas       %fr2 = freeze i32 %c
3405400257dSAlexandros Lamprineas       %add = add i32 %call, %fr2
3415400257dSAlexandros Lamprineas       ret i32 %add
3424d13896dSAlexandros Lamprineas     }
3434d13896dSAlexandros Lamprineas   )";
3444d13896dSAlexandros Lamprineas 
3454d13896dSAlexandros Lamprineas   Module &M = parseModule(ModuleString);
3464d13896dSAlexandros Lamprineas   Function *F = M.getFunction("foo");
3474d13896dSAlexandros Lamprineas   FunctionSpecializer Specializer = getSpecializerFor(F);
3484d13896dSAlexandros Lamprineas   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
3494d13896dSAlexandros Lamprineas 
3504d13896dSAlexandros Lamprineas   GlobalVariable *GV = M.getGlobalVariable("g");
3514d13896dSAlexandros Lamprineas   Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1);
3524d13896dSAlexandros Lamprineas   Constant *True = ConstantInt::getTrue(M.getContext());
3535400257dSAlexandros Lamprineas   Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext()));
3544d13896dSAlexandros Lamprineas 
3554d13896dSAlexandros Lamprineas   auto BlockIter = F->front().begin();
3564d13896dSAlexandros Lamprineas   Instruction &Icmp = *BlockIter++;
3574d13896dSAlexandros Lamprineas   Instruction &Zext = *BlockIter++;
3584d13896dSAlexandros Lamprineas   Instruction &Select = *BlockIter++;
3594d13896dSAlexandros Lamprineas   Instruction &Gep = *BlockIter++;
3604d13896dSAlexandros Lamprineas   Instruction &Load = *BlockIter++;
3615400257dSAlexandros Lamprineas   Instruction &Freeze = *BlockIter++;
3625400257dSAlexandros Lamprineas   Instruction &Smax = *BlockIter++;
3634d13896dSAlexandros Lamprineas 
3644d13896dSAlexandros Lamprineas   // icmp + zext
365c6931c25SHari Limaye   Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext);
366c6931c25SHari Limaye   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
3675bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
368c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
3694d13896dSAlexandros Lamprineas 
3704d13896dSAlexandros Lamprineas   // select
371c6931c25SHari Limaye   Ref = getCodeSizeSavings(Select);
372c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True);
3735bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
374c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
3754d13896dSAlexandros Lamprineas 
3765400257dSAlexandros Lamprineas   // gep + load + freeze + smax
377c6931c25SHari Limaye   Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) +
378c6931c25SHari Limaye         getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax);
379c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV);
3805bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
381c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
3825400257dSAlexandros Lamprineas 
383c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef);
384c6931c25SHari Limaye   EXPECT_TRUE(Test == 0);
385c6931c25SHari Limaye 
386c6931c25SHari Limaye   // Latency.
387c6931c25SHari Limaye   Ref = getLatencySavings(F);
388c6931c25SHari Limaye   Test = Visitor.getLatencySavingsForKnownConstants();
389c6931c25SHari Limaye   EXPECT_EQ(Test, Ref);
390c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
3914d13896dSAlexandros Lamprineas }
392893d3a61SAlexandros Lamprineas 
393893d3a61SAlexandros Lamprineas TEST_F(FunctionSpecializationTest, PhiNode) {
394893d3a61SAlexandros Lamprineas   const char *ModuleString = R"(
395893d3a61SAlexandros Lamprineas     define void @foo(i32 %a, i32 %b, i32 %i) {
396893d3a61SAlexandros Lamprineas     entry:
397893d3a61SAlexandros Lamprineas       br label %loop
398893d3a61SAlexandros Lamprineas     loop:
399893d3a61SAlexandros Lamprineas       %0 = phi i32 [ %a, %entry ], [ %3, %bb ]
400893d3a61SAlexandros Lamprineas       switch i32 %i, label %default
401893d3a61SAlexandros Lamprineas       [ i32 1, label %case1
402893d3a61SAlexandros Lamprineas         i32 2, label %case2 ]
403893d3a61SAlexandros Lamprineas     case1:
404893d3a61SAlexandros Lamprineas       %1 = add i32 %0, 1
405893d3a61SAlexandros Lamprineas       br label %bb
406893d3a61SAlexandros Lamprineas     case2:
407893d3a61SAlexandros Lamprineas       %2 = phi i32 [ %a, %entry ], [ %0, %loop ]
408893d3a61SAlexandros Lamprineas       br label %bb
409893d3a61SAlexandros Lamprineas     bb:
410893d3a61SAlexandros Lamprineas       %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ]
411893d3a61SAlexandros Lamprineas       %4 = icmp eq i32 %3, 1
412893d3a61SAlexandros Lamprineas       br i1 %4, label %bb, label %loop
413893d3a61SAlexandros Lamprineas     default:
414893d3a61SAlexandros Lamprineas       ret void
415893d3a61SAlexandros Lamprineas     }
416893d3a61SAlexandros Lamprineas   )";
417893d3a61SAlexandros Lamprineas 
418893d3a61SAlexandros Lamprineas   Module &M = parseModule(ModuleString);
419893d3a61SAlexandros Lamprineas   Function *F = M.getFunction("foo");
420893d3a61SAlexandros Lamprineas   FunctionSpecializer Specializer = getSpecializerFor(F);
421893d3a61SAlexandros Lamprineas   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
422893d3a61SAlexandros Lamprineas 
423893d3a61SAlexandros Lamprineas   Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1);
424893d3a61SAlexandros Lamprineas 
425893d3a61SAlexandros Lamprineas   auto FuncIter = F->begin();
426893d3a61SAlexandros Lamprineas   BasicBlock &Loop = *++FuncIter;
427893d3a61SAlexandros Lamprineas   BasicBlock &Case1 = *++FuncIter;
428893d3a61SAlexandros Lamprineas   BasicBlock &Case2 = *++FuncIter;
429893d3a61SAlexandros Lamprineas   BasicBlock &BB = *++FuncIter;
430893d3a61SAlexandros Lamprineas 
431893d3a61SAlexandros Lamprineas   Instruction &PhiLoop = Loop.front();
4325bfefff1SAlexandros Lamprineas   Instruction &Switch = Loop.back();
433893d3a61SAlexandros Lamprineas   Instruction &Add = Case1.front();
434893d3a61SAlexandros Lamprineas   Instruction &PhiCase2 = Case2.front();
435893d3a61SAlexandros Lamprineas   Instruction &BrBB = Case2.back();
436893d3a61SAlexandros Lamprineas   Instruction &PhiBB = BB.front();
437893d3a61SAlexandros Lamprineas   Instruction &Icmp = *++BB.begin();
4385bfefff1SAlexandros Lamprineas   Instruction &Branch = BB.back();
439893d3a61SAlexandros Lamprineas 
440c6931c25SHari Limaye   Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One);
441c6931c25SHari Limaye   EXPECT_TRUE(Test == 0);
442893d3a61SAlexandros Lamprineas 
443c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One);
444c6931c25SHari Limaye   EXPECT_TRUE(Test == 0);
445c6931c25SHari Limaye 
446c6931c25SHari Limaye   Test = Visitor.getLatencySavingsForKnownConstants();
447c6931c25SHari Limaye   EXPECT_TRUE(Test == 0);
448893d3a61SAlexandros Lamprineas 
4495bfefff1SAlexandros Lamprineas   // switch + phi + br
450c6931c25SHari Limaye   Cost Ref = getCodeSizeSavings(Switch) +
451c6931c25SHari Limaye              getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) +
452c6931c25SHari Limaye              getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false);
453c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One);
4545bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
455c6931c25SHari Limaye   EXPECT_TRUE(Test > 0 && Test > 0);
456893d3a61SAlexandros Lamprineas 
4575bfefff1SAlexandros Lamprineas   // phi + phi + add + icmp + branch
458c6931c25SHari Limaye   Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) +
459c6931c25SHari Limaye         getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) +
460c6931c25SHari Limaye         getCodeSizeSavings(Branch);
461c6931c25SHari Limaye   Test = Visitor.getCodeSizeSavingsFromPendingPHIs();
4625bfefff1SAlexandros Lamprineas   EXPECT_EQ(Test, Ref);
463c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
464c6931c25SHari Limaye 
465c6931c25SHari Limaye   // Latency.
466c6931c25SHari Limaye   Ref = getLatencySavings(F);
467c6931c25SHari Limaye   Test = Visitor.getLatencySavingsForKnownConstants();
468c6931c25SHari Limaye   EXPECT_EQ(Test, Ref);
469c6931c25SHari Limaye   EXPECT_TRUE(Test > 0);
470893d3a61SAlexandros Lamprineas }
471893d3a61SAlexandros Lamprineas 
472*5f30b1aaSHari Limaye TEST_F(FunctionSpecializationTest, BinOp) {
473*5f30b1aaSHari Limaye   // Verify that we can handle binary operators even when only one operand is
474*5f30b1aaSHari Limaye   // constant.
475*5f30b1aaSHari Limaye   const char *ModuleString = R"(
476*5f30b1aaSHari Limaye     define i32 @foo(i1 %a, i1 %b) {
477*5f30b1aaSHari Limaye       %and1 = and i1 %a, %b
478*5f30b1aaSHari Limaye       %and2 = and i1 %b, %and1
479*5f30b1aaSHari Limaye       %sel = select i1 %and2, i32 1, i32 0
480*5f30b1aaSHari Limaye       ret i32 %sel
481*5f30b1aaSHari Limaye     }
482*5f30b1aaSHari Limaye   )";
483*5f30b1aaSHari Limaye 
484*5f30b1aaSHari Limaye   Module &M = parseModule(ModuleString);
485*5f30b1aaSHari Limaye   Function *F = M.getFunction("foo");
486*5f30b1aaSHari Limaye   FunctionSpecializer Specializer = getSpecializerFor(F);
487*5f30b1aaSHari Limaye   InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F);
488*5f30b1aaSHari Limaye 
489*5f30b1aaSHari Limaye   Constant *False = ConstantInt::getFalse(M.getContext());
490*5f30b1aaSHari Limaye   BasicBlock &BB = F->front();
491*5f30b1aaSHari Limaye   Instruction &And1 = BB.front();
492*5f30b1aaSHari Limaye   Instruction &And2 = *++BB.begin();
493*5f30b1aaSHari Limaye   Instruction &Select = *++BB.begin();
494*5f30b1aaSHari Limaye 
495*5f30b1aaSHari Limaye   Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) +
496*5f30b1aaSHari Limaye                      getCodeSizeSavings(Select);
497*5f30b1aaSHari Limaye   Cost RefLatency = getLatencySavings(F);
498*5f30b1aaSHari Limaye 
499*5f30b1aaSHari Limaye   Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False);
500*5f30b1aaSHari Limaye   Cost TestLatency = Visitor.getLatencySavingsForKnownConstants();
501*5f30b1aaSHari Limaye 
502*5f30b1aaSHari Limaye   EXPECT_EQ(TestCodeSize, RefCodeSize);
503*5f30b1aaSHari Limaye   EXPECT_TRUE(TestCodeSize > 0);
504*5f30b1aaSHari Limaye   EXPECT_EQ(TestLatency, RefLatency);
505*5f30b1aaSHari Limaye   EXPECT_TRUE(TestLatency > 0);
506*5f30b1aaSHari Limaye }
507