14d13896dSAlexandros Lamprineas //===- FunctionSpecializationTest.cpp - Cost model unit tests -------------===// 24d13896dSAlexandros Lamprineas // 34d13896dSAlexandros Lamprineas // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 44d13896dSAlexandros Lamprineas // See https://llvm.org/LICENSE.txt for license information. 54d13896dSAlexandros Lamprineas // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 64d13896dSAlexandros Lamprineas // 74d13896dSAlexandros Lamprineas //===----------------------------------------------------------------------===// 84d13896dSAlexandros Lamprineas 94d13896dSAlexandros Lamprineas #include "llvm/Analysis/AssumptionCache.h" 104d13896dSAlexandros Lamprineas #include "llvm/Analysis/BlockFrequencyInfo.h" 114d13896dSAlexandros Lamprineas #include "llvm/Analysis/BranchProbabilityInfo.h" 124d13896dSAlexandros Lamprineas #include "llvm/Analysis/LoopInfo.h" 134d13896dSAlexandros Lamprineas #include "llvm/Analysis/PostDominators.h" 144d13896dSAlexandros Lamprineas #include "llvm/Analysis/TargetLibraryInfo.h" 154d13896dSAlexandros Lamprineas #include "llvm/Analysis/TargetTransformInfo.h" 164d13896dSAlexandros Lamprineas #include "llvm/AsmParser/Parser.h" 174d13896dSAlexandros Lamprineas #include "llvm/IR/Constants.h" 1836c6632eSNikita Popov #include "llvm/IR/PassInstrumentation.h" 194d13896dSAlexandros Lamprineas #include "llvm/Support/SourceMgr.h" 204d13896dSAlexandros Lamprineas #include "llvm/Transforms/IPO/FunctionSpecialization.h" 214d13896dSAlexandros Lamprineas #include "llvm/Transforms/Utils/SCCPSolver.h" 224d13896dSAlexandros Lamprineas #include "gtest/gtest.h" 234d13896dSAlexandros Lamprineas #include <memory> 244d13896dSAlexandros Lamprineas 254d13896dSAlexandros Lamprineas namespace llvm { 264d13896dSAlexandros Lamprineas 272e00eba2SAlexandros Lamprineas static void removeSSACopy(Function &F) { 282e00eba2SAlexandros Lamprineas for (BasicBlock &BB : F) { 292e00eba2SAlexandros Lamprineas for (Instruction &Inst : llvm::make_early_inc_range(BB)) { 302e00eba2SAlexandros Lamprineas if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) { 312e00eba2SAlexandros Lamprineas if (II->getIntrinsicID() != Intrinsic::ssa_copy) 322e00eba2SAlexandros Lamprineas continue; 332e00eba2SAlexandros Lamprineas Inst.replaceAllUsesWith(II->getOperand(0)); 342e00eba2SAlexandros Lamprineas Inst.eraseFromParent(); 352e00eba2SAlexandros Lamprineas } 362e00eba2SAlexandros Lamprineas } 372e00eba2SAlexandros Lamprineas } 382e00eba2SAlexandros Lamprineas } 392e00eba2SAlexandros Lamprineas 404d13896dSAlexandros Lamprineas class FunctionSpecializationTest : public testing::Test { 414d13896dSAlexandros Lamprineas protected: 424d13896dSAlexandros Lamprineas LLVMContext Ctx; 434d13896dSAlexandros Lamprineas FunctionAnalysisManager FAM; 444d13896dSAlexandros Lamprineas std::unique_ptr<Module> M; 454d13896dSAlexandros Lamprineas std::unique_ptr<SCCPSolver> Solver; 46c6931c25SHari Limaye SmallVector<Instruction *, 8> KnownConstants; 474d13896dSAlexandros Lamprineas 484d13896dSAlexandros Lamprineas FunctionSpecializationTest() { 494d13896dSAlexandros Lamprineas FAM.registerPass([&] { return TargetLibraryAnalysis(); }); 504d13896dSAlexandros Lamprineas FAM.registerPass([&] { return TargetIRAnalysis(); }); 514d13896dSAlexandros Lamprineas FAM.registerPass([&] { return BlockFrequencyAnalysis(); }); 524d13896dSAlexandros Lamprineas FAM.registerPass([&] { return BranchProbabilityAnalysis(); }); 534d13896dSAlexandros Lamprineas FAM.registerPass([&] { return LoopAnalysis(); }); 544d13896dSAlexandros Lamprineas FAM.registerPass([&] { return AssumptionAnalysis(); }); 554d13896dSAlexandros Lamprineas FAM.registerPass([&] { return DominatorTreeAnalysis(); }); 564d13896dSAlexandros Lamprineas FAM.registerPass([&] { return PostDominatorTreeAnalysis(); }); 574d13896dSAlexandros Lamprineas FAM.registerPass([&] { return PassInstrumentationAnalysis(); }); 584d13896dSAlexandros Lamprineas } 594d13896dSAlexandros Lamprineas 604d13896dSAlexandros Lamprineas Module &parseModule(const char *ModuleString) { 614d13896dSAlexandros Lamprineas SMDiagnostic Err; 624d13896dSAlexandros Lamprineas M = parseAssemblyString(ModuleString, Err, Ctx); 634d13896dSAlexandros Lamprineas EXPECT_TRUE(M); 644d13896dSAlexandros Lamprineas return *M; 654d13896dSAlexandros Lamprineas } 664d13896dSAlexandros Lamprineas 674d13896dSAlexandros Lamprineas FunctionSpecializer getSpecializerFor(Function *F) { 684d13896dSAlexandros Lamprineas auto GetTLI = [this](Function &F) -> const TargetLibraryInfo & { 694d13896dSAlexandros Lamprineas return FAM.getResult<TargetLibraryAnalysis>(F); 704d13896dSAlexandros Lamprineas }; 714d13896dSAlexandros Lamprineas auto GetTTI = [this](Function &F) -> TargetTransformInfo & { 724d13896dSAlexandros Lamprineas return FAM.getResult<TargetIRAnalysis>(F); 734d13896dSAlexandros Lamprineas }; 744d13896dSAlexandros Lamprineas auto GetAC = [this](Function &F) -> AssumptionCache & { 754d13896dSAlexandros Lamprineas return FAM.getResult<AssumptionAnalysis>(F); 764d13896dSAlexandros Lamprineas }; 774d13896dSAlexandros Lamprineas auto GetDT = [this](Function &F) -> DominatorTree & { 784d13896dSAlexandros Lamprineas return FAM.getResult<DominatorTreeAnalysis>(F); 794d13896dSAlexandros Lamprineas }; 804d13896dSAlexandros Lamprineas auto GetBFI = [this](Function &F) -> BlockFrequencyInfo & { 814d13896dSAlexandros Lamprineas return FAM.getResult<BlockFrequencyAnalysis>(F); 824d13896dSAlexandros Lamprineas }; 834d13896dSAlexandros Lamprineas 844d13896dSAlexandros Lamprineas Solver = std::make_unique<SCCPSolver>(M->getDataLayout(), GetTLI, Ctx); 854d13896dSAlexandros Lamprineas 864d13896dSAlexandros Lamprineas DominatorTree &DT = GetDT(*F); 874d13896dSAlexandros Lamprineas AssumptionCache &AC = GetAC(*F); 884d13896dSAlexandros Lamprineas Solver->addPredicateInfo(*F, DT, AC); 894d13896dSAlexandros Lamprineas 904d13896dSAlexandros Lamprineas Solver->markBlockExecutable(&F->front()); 914d13896dSAlexandros Lamprineas for (Argument &Arg : F->args()) 924d13896dSAlexandros Lamprineas Solver->markOverdefined(&Arg); 934d13896dSAlexandros Lamprineas Solver->solveWhileResolvedUndefsIn(*M); 944d13896dSAlexandros Lamprineas 952e00eba2SAlexandros Lamprineas removeSSACopy(*F); 962e00eba2SAlexandros Lamprineas 974d13896dSAlexandros Lamprineas return FunctionSpecializer(*Solver, *M, &FAM, GetBFI, GetTLI, GetTTI, 984d13896dSAlexandros Lamprineas GetAC); 994d13896dSAlexandros Lamprineas } 1004d13896dSAlexandros Lamprineas 101c6931c25SHari Limaye Cost getCodeSizeSavings(Instruction &I, bool HasLatencySavings = true) { 1024d13896dSAlexandros Lamprineas auto &TTI = FAM.getResult<TargetIRAnalysis>(*I.getFunction()); 1034d13896dSAlexandros Lamprineas 1045bfefff1SAlexandros Lamprineas Cost CodeSize = 1055bfefff1SAlexandros Lamprineas TTI.getInstructionCost(&I, TargetTransformInfo::TCK_CodeSize); 1065bfefff1SAlexandros Lamprineas 107c6931c25SHari Limaye if (HasLatencySavings) 108c6931c25SHari Limaye KnownConstants.push_back(&I); 1095bfefff1SAlexandros Lamprineas 110c6931c25SHari Limaye return CodeSize; 111c6931c25SHari Limaye } 112c6931c25SHari Limaye 113c6931c25SHari Limaye Cost getLatencySavings(Function *F) { 114c6931c25SHari Limaye auto &TTI = FAM.getResult<TargetIRAnalysis>(*F); 115c6931c25SHari Limaye auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(*F); 116c6931c25SHari Limaye 117c6931c25SHari Limaye Cost Latency = 0; 118c6931c25SHari Limaye for (const Instruction *I : KnownConstants) 119c6931c25SHari Limaye Latency += BFI.getBlockFreq(I->getParent()).getFrequency() / 120c6931c25SHari Limaye BFI.getEntryFreq().getFrequency() * 121c6931c25SHari Limaye TTI.getInstructionCost(I, TargetTransformInfo::TCK_Latency); 122c6931c25SHari Limaye 123c6931c25SHari Limaye return Latency; 1244d13896dSAlexandros Lamprineas } 1254d13896dSAlexandros Lamprineas }; 1264d13896dSAlexandros Lamprineas 1274d13896dSAlexandros Lamprineas } // namespace llvm 1284d13896dSAlexandros Lamprineas 1294d13896dSAlexandros Lamprineas using namespace llvm; 1304d13896dSAlexandros Lamprineas 1314d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, SwitchInst) { 1324d13896dSAlexandros Lamprineas const char *ModuleString = R"( 1334d13896dSAlexandros Lamprineas define void @foo(i32 %a, i32 %b, i32 %i) { 1344d13896dSAlexandros Lamprineas entry: 135f11d8c88SAlexandros Lamprineas br label %loop 136f11d8c88SAlexandros Lamprineas loop: 1374d13896dSAlexandros Lamprineas switch i32 %i, label %default 1384d13896dSAlexandros Lamprineas [ i32 1, label %case1 1394d13896dSAlexandros Lamprineas i32 2, label %case2 ] 1404d13896dSAlexandros Lamprineas case1: 1414d13896dSAlexandros Lamprineas %0 = mul i32 %a, 2 1424d13896dSAlexandros Lamprineas %1 = sub i32 6, 5 1434d13896dSAlexandros Lamprineas br label %bb1 1444d13896dSAlexandros Lamprineas case2: 1454d13896dSAlexandros Lamprineas %2 = and i32 %b, 3 1464d13896dSAlexandros Lamprineas %3 = sdiv i32 8, 2 1474d13896dSAlexandros Lamprineas br label %bb2 1484d13896dSAlexandros Lamprineas bb1: 1494d13896dSAlexandros Lamprineas %4 = add i32 %0, %b 150f11d8c88SAlexandros Lamprineas br label %loop 1514d13896dSAlexandros Lamprineas bb2: 1524d13896dSAlexandros Lamprineas %5 = or i32 %2, %a 153f11d8c88SAlexandros Lamprineas br label %loop 1544d13896dSAlexandros Lamprineas default: 1554d13896dSAlexandros Lamprineas ret void 1564d13896dSAlexandros Lamprineas } 1574d13896dSAlexandros Lamprineas )"; 1584d13896dSAlexandros Lamprineas 1594d13896dSAlexandros Lamprineas Module &M = parseModule(ModuleString); 1604d13896dSAlexandros Lamprineas Function *F = M.getFunction("foo"); 1614d13896dSAlexandros Lamprineas FunctionSpecializer Specializer = getSpecializerFor(F); 1624d13896dSAlexandros Lamprineas InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 1634d13896dSAlexandros Lamprineas 1644d13896dSAlexandros Lamprineas Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 1654d13896dSAlexandros Lamprineas 1664d13896dSAlexandros Lamprineas auto FuncIter = F->begin(); 1675bfefff1SAlexandros Lamprineas BasicBlock &Loop = *++FuncIter; 1684d13896dSAlexandros Lamprineas BasicBlock &Case1 = *++FuncIter; 1694d13896dSAlexandros Lamprineas BasicBlock &Case2 = *++FuncIter; 1704d13896dSAlexandros Lamprineas BasicBlock &BB1 = *++FuncIter; 1714d13896dSAlexandros Lamprineas BasicBlock &BB2 = *++FuncIter; 1724d13896dSAlexandros Lamprineas 1735bfefff1SAlexandros Lamprineas Instruction &Switch = Loop.front(); 1744d13896dSAlexandros Lamprineas Instruction &Mul = Case1.front(); 1754d13896dSAlexandros Lamprineas Instruction &And = Case2.front(); 1764d13896dSAlexandros Lamprineas Instruction &Sdiv = *++Case2.begin(); 1774d13896dSAlexandros Lamprineas Instruction &BrBB2 = Case2.back(); 1784d13896dSAlexandros Lamprineas Instruction &Add = BB1.front(); 1794d13896dSAlexandros Lamprineas Instruction &Or = BB2.front(); 180f11d8c88SAlexandros Lamprineas Instruction &BrLoop = BB2.back(); 1814d13896dSAlexandros Lamprineas 1824d13896dSAlexandros Lamprineas // mul 183c6931c25SHari Limaye Cost Ref = getCodeSizeSavings(Mul); 184c6931c25SHari Limaye Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 1855bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 186c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 1874d13896dSAlexandros Lamprineas 1884d13896dSAlexandros Lamprineas // and + or + add 189c6931c25SHari Limaye Ref = getCodeSizeSavings(And) + getCodeSizeSavings(Or) + 190c6931c25SHari Limaye getCodeSizeSavings(Add); 191c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 1925bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 193c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 1944d13896dSAlexandros Lamprineas 1955bfefff1SAlexandros Lamprineas // switch + sdiv + br + br 196c6931c25SHari Limaye Ref = getCodeSizeSavings(Switch) + 197c6931c25SHari Limaye getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) + 198c6931c25SHari Limaye getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) + 199c6931c25SHari Limaye getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false); 200c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One); 2015bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 202c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 203c6931c25SHari Limaye 204c6931c25SHari Limaye // Latency. 205c6931c25SHari Limaye Ref = getLatencySavings(F); 206c6931c25SHari Limaye Test = Visitor.getLatencySavingsForKnownConstants(); 207c6931c25SHari Limaye EXPECT_EQ(Test, Ref); 208c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 2094d13896dSAlexandros Lamprineas } 2104d13896dSAlexandros Lamprineas 2114d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, BranchInst) { 2124d13896dSAlexandros Lamprineas const char *ModuleString = R"( 2134d13896dSAlexandros Lamprineas define void @foo(i32 %a, i32 %b, i1 %cond) { 2144d13896dSAlexandros Lamprineas entry: 215f11d8c88SAlexandros Lamprineas br label %loop 216f11d8c88SAlexandros Lamprineas loop: 217c2d19002SAlexandros Lamprineas br i1 %cond, label %bb0, label %bb3 2184d13896dSAlexandros Lamprineas bb0: 2194d13896dSAlexandros Lamprineas %0 = mul i32 %a, 2 2204d13896dSAlexandros Lamprineas %1 = sub i32 6, 5 221c2d19002SAlexandros Lamprineas br i1 %cond, label %bb1, label %bb2 2224d13896dSAlexandros Lamprineas bb1: 2234d13896dSAlexandros Lamprineas %2 = add i32 %0, %b 2244d13896dSAlexandros Lamprineas %3 = sdiv i32 8, 2 225c2d19002SAlexandros Lamprineas br label %bb2 2264d13896dSAlexandros Lamprineas bb2: 227c2d19002SAlexandros Lamprineas br label %loop 228c2d19002SAlexandros Lamprineas bb3: 2294d13896dSAlexandros Lamprineas ret void 2304d13896dSAlexandros Lamprineas } 2314d13896dSAlexandros Lamprineas )"; 2324d13896dSAlexandros Lamprineas 2334d13896dSAlexandros Lamprineas Module &M = parseModule(ModuleString); 2344d13896dSAlexandros Lamprineas Function *F = M.getFunction("foo"); 2354d13896dSAlexandros Lamprineas FunctionSpecializer Specializer = getSpecializerFor(F); 2364d13896dSAlexandros Lamprineas InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 2374d13896dSAlexandros Lamprineas 2384d13896dSAlexandros Lamprineas Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 2394d13896dSAlexandros Lamprineas Constant *False = ConstantInt::getFalse(M.getContext()); 2404d13896dSAlexandros Lamprineas 2414d13896dSAlexandros Lamprineas auto FuncIter = F->begin(); 2425bfefff1SAlexandros Lamprineas BasicBlock &Loop = *++FuncIter; 2434d13896dSAlexandros Lamprineas BasicBlock &BB0 = *++FuncIter; 2444d13896dSAlexandros Lamprineas BasicBlock &BB1 = *++FuncIter; 245c2d19002SAlexandros Lamprineas BasicBlock &BB2 = *++FuncIter; 2464d13896dSAlexandros Lamprineas 2475bfefff1SAlexandros Lamprineas Instruction &Branch = Loop.front(); 2484d13896dSAlexandros Lamprineas Instruction &Mul = BB0.front(); 2494d13896dSAlexandros Lamprineas Instruction &Sub = *++BB0.begin(); 250c2d19002SAlexandros Lamprineas Instruction &BrBB1BB2 = BB0.back(); 2514d13896dSAlexandros Lamprineas Instruction &Add = BB1.front(); 2524d13896dSAlexandros Lamprineas Instruction &Sdiv = *++BB1.begin(); 253c2d19002SAlexandros Lamprineas Instruction &BrBB2 = BB1.back(); 254c2d19002SAlexandros Lamprineas Instruction &BrLoop = BB2.front(); 2554d13896dSAlexandros Lamprineas 2564d13896dSAlexandros Lamprineas // mul 257c6931c25SHari Limaye Cost Ref = getCodeSizeSavings(Mul); 258c6931c25SHari Limaye Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 2595bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 260c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 2614d13896dSAlexandros Lamprineas 2624d13896dSAlexandros Lamprineas // add 263c6931c25SHari Limaye Ref = getCodeSizeSavings(Add); 264c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 2655bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 266c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 2674d13896dSAlexandros Lamprineas 2685bfefff1SAlexandros Lamprineas // branch + sub + br + sdiv + br 269c6931c25SHari Limaye Ref = getCodeSizeSavings(Branch) + 270c6931c25SHari Limaye getCodeSizeSavings(Sub, /*HasLatencySavings=*/false) + 271c6931c25SHari Limaye getCodeSizeSavings(BrBB1BB2) + 272c6931c25SHari Limaye getCodeSizeSavings(Sdiv, /*HasLatencySavings=*/false) + 273c6931c25SHari Limaye getCodeSizeSavings(BrBB2, /*HasLatencySavings=*/false) + 274c6931c25SHari Limaye getCodeSizeSavings(BrLoop, /*HasLatencySavings=*/false); 275c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), False); 2765bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 277c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 278c6931c25SHari Limaye 279c6931c25SHari Limaye // Latency. 280c6931c25SHari Limaye Ref = getLatencySavings(F); 281c6931c25SHari Limaye Test = Visitor.getLatencySavingsForKnownConstants(); 282c6931c25SHari Limaye EXPECT_EQ(Test, Ref); 283c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 2844d13896dSAlexandros Lamprineas } 2854d13896dSAlexandros Lamprineas 2866472cb1eSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, SelectInst) { 2876472cb1eSAlexandros Lamprineas const char *ModuleString = R"( 2886472cb1eSAlexandros Lamprineas define i32 @foo(i1 %cond, i32 %a, i32 %b) { 2896472cb1eSAlexandros Lamprineas %sel = select i1 %cond, i32 %a, i32 %b 2906472cb1eSAlexandros Lamprineas ret i32 %sel 2916472cb1eSAlexandros Lamprineas } 2926472cb1eSAlexandros Lamprineas )"; 2936472cb1eSAlexandros Lamprineas 2946472cb1eSAlexandros Lamprineas Module &M = parseModule(ModuleString); 2956472cb1eSAlexandros Lamprineas Function *F = M.getFunction("foo"); 2966472cb1eSAlexandros Lamprineas FunctionSpecializer Specializer = getSpecializerFor(F); 2976472cb1eSAlexandros Lamprineas InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 2986472cb1eSAlexandros Lamprineas 2996472cb1eSAlexandros Lamprineas Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 3006472cb1eSAlexandros Lamprineas Constant *Zero = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 0); 3016472cb1eSAlexandros Lamprineas Constant *False = ConstantInt::getFalse(M.getContext()); 3026472cb1eSAlexandros Lamprineas Instruction &Select = *F->front().begin(); 3036472cb1eSAlexandros Lamprineas 304c6931c25SHari Limaye Cost RefCodeSize = getCodeSizeSavings(Select); 305c6931c25SHari Limaye Cost RefLatency = getLatencySavings(F); 306c6931c25SHari Limaye 307c6931c25SHari Limaye Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False); 308c6931c25SHari Limaye EXPECT_TRUE(TestCodeSize == 0); 309c6931c25SHari Limaye TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 310c6931c25SHari Limaye EXPECT_TRUE(TestCodeSize == 0); 311c6931c25SHari Limaye Cost TestLatency = Visitor.getLatencySavingsForKnownConstants(); 312c6931c25SHari Limaye EXPECT_TRUE(TestLatency == 0); 313c6931c25SHari Limaye 314c6931c25SHari Limaye TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(2), Zero); 315c6931c25SHari Limaye EXPECT_EQ(TestCodeSize, RefCodeSize); 316c6931c25SHari Limaye EXPECT_TRUE(TestCodeSize > 0); 317c6931c25SHari Limaye TestLatency = Visitor.getLatencySavingsForKnownConstants(); 318c6931c25SHari Limaye EXPECT_EQ(TestLatency, RefLatency); 319c6931c25SHari Limaye EXPECT_TRUE(TestLatency > 0); 3206472cb1eSAlexandros Lamprineas } 3216472cb1eSAlexandros Lamprineas 3224d13896dSAlexandros Lamprineas TEST_F(FunctionSpecializationTest, Misc) { 3234d13896dSAlexandros Lamprineas const char *ModuleString = R"( 324cae00b2aSAlexandros Lamprineas %struct_t = type { [8 x i16], [8 x i16], i32, i32, i32, ptr, [8 x i8] } 325cae00b2aSAlexandros Lamprineas @g = constant %struct_t zeroinitializer, align 16 3264d13896dSAlexandros Lamprineas 3275400257dSAlexandros Lamprineas declare i32 @llvm.smax.i32(i32, i32) 3285400257dSAlexandros Lamprineas declare i32 @bar(i32) 3295400257dSAlexandros Lamprineas 3305400257dSAlexandros Lamprineas define i32 @foo(i8 %a, i1 %cond, ptr %b, i32 %c) { 3314d13896dSAlexandros Lamprineas %cmp = icmp eq i8 %a, 10 332cae00b2aSAlexandros Lamprineas %ext = zext i1 %cmp to i64 333cae00b2aSAlexandros Lamprineas %sel = select i1 %cond, i64 %ext, i64 1 334cae00b2aSAlexandros Lamprineas %gep = getelementptr inbounds %struct_t, ptr %b, i64 %sel, i32 4 3354d13896dSAlexandros Lamprineas %ld = load i32, ptr %gep 3365400257dSAlexandros Lamprineas %fr = freeze i32 %ld 3375400257dSAlexandros Lamprineas %smax = call i32 @llvm.smax.i32(i32 %fr, i32 1) 3385400257dSAlexandros Lamprineas %call = call i32 @bar(i32 %smax) 3395400257dSAlexandros Lamprineas %fr2 = freeze i32 %c 3405400257dSAlexandros Lamprineas %add = add i32 %call, %fr2 3415400257dSAlexandros Lamprineas ret i32 %add 3424d13896dSAlexandros Lamprineas } 3434d13896dSAlexandros Lamprineas )"; 3444d13896dSAlexandros Lamprineas 3454d13896dSAlexandros Lamprineas Module &M = parseModule(ModuleString); 3464d13896dSAlexandros Lamprineas Function *F = M.getFunction("foo"); 3474d13896dSAlexandros Lamprineas FunctionSpecializer Specializer = getSpecializerFor(F); 3484d13896dSAlexandros Lamprineas InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 3494d13896dSAlexandros Lamprineas 3504d13896dSAlexandros Lamprineas GlobalVariable *GV = M.getGlobalVariable("g"); 3514d13896dSAlexandros Lamprineas Constant *One = ConstantInt::get(IntegerType::getInt8Ty(M.getContext()), 1); 3524d13896dSAlexandros Lamprineas Constant *True = ConstantInt::getTrue(M.getContext()); 3535400257dSAlexandros Lamprineas Constant *Undef = UndefValue::get(IntegerType::getInt32Ty(M.getContext())); 3544d13896dSAlexandros Lamprineas 3554d13896dSAlexandros Lamprineas auto BlockIter = F->front().begin(); 3564d13896dSAlexandros Lamprineas Instruction &Icmp = *BlockIter++; 3574d13896dSAlexandros Lamprineas Instruction &Zext = *BlockIter++; 3584d13896dSAlexandros Lamprineas Instruction &Select = *BlockIter++; 3594d13896dSAlexandros Lamprineas Instruction &Gep = *BlockIter++; 3604d13896dSAlexandros Lamprineas Instruction &Load = *BlockIter++; 3615400257dSAlexandros Lamprineas Instruction &Freeze = *BlockIter++; 3625400257dSAlexandros Lamprineas Instruction &Smax = *BlockIter++; 3634d13896dSAlexandros Lamprineas 3644d13896dSAlexandros Lamprineas // icmp + zext 365c6931c25SHari Limaye Cost Ref = getCodeSizeSavings(Icmp) + getCodeSizeSavings(Zext); 366c6931c25SHari Limaye Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 3675bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 368c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 3694d13896dSAlexandros Lamprineas 3704d13896dSAlexandros Lamprineas // select 371c6931c25SHari Limaye Ref = getCodeSizeSavings(Select); 372c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), True); 3735bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 374c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 3754d13896dSAlexandros Lamprineas 3765400257dSAlexandros Lamprineas // gep + load + freeze + smax 377c6931c25SHari Limaye Ref = getCodeSizeSavings(Gep) + getCodeSizeSavings(Load) + 378c6931c25SHari Limaye getCodeSizeSavings(Freeze) + getCodeSizeSavings(Smax); 379c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), GV); 3805bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 381c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 3825400257dSAlexandros Lamprineas 383c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(3), Undef); 384c6931c25SHari Limaye EXPECT_TRUE(Test == 0); 385c6931c25SHari Limaye 386c6931c25SHari Limaye // Latency. 387c6931c25SHari Limaye Ref = getLatencySavings(F); 388c6931c25SHari Limaye Test = Visitor.getLatencySavingsForKnownConstants(); 389c6931c25SHari Limaye EXPECT_EQ(Test, Ref); 390c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 3914d13896dSAlexandros Lamprineas } 392893d3a61SAlexandros Lamprineas 393893d3a61SAlexandros Lamprineas TEST_F(FunctionSpecializationTest, PhiNode) { 394893d3a61SAlexandros Lamprineas const char *ModuleString = R"( 395893d3a61SAlexandros Lamprineas define void @foo(i32 %a, i32 %b, i32 %i) { 396893d3a61SAlexandros Lamprineas entry: 397893d3a61SAlexandros Lamprineas br label %loop 398893d3a61SAlexandros Lamprineas loop: 399893d3a61SAlexandros Lamprineas %0 = phi i32 [ %a, %entry ], [ %3, %bb ] 400893d3a61SAlexandros Lamprineas switch i32 %i, label %default 401893d3a61SAlexandros Lamprineas [ i32 1, label %case1 402893d3a61SAlexandros Lamprineas i32 2, label %case2 ] 403893d3a61SAlexandros Lamprineas case1: 404893d3a61SAlexandros Lamprineas %1 = add i32 %0, 1 405893d3a61SAlexandros Lamprineas br label %bb 406893d3a61SAlexandros Lamprineas case2: 407893d3a61SAlexandros Lamprineas %2 = phi i32 [ %a, %entry ], [ %0, %loop ] 408893d3a61SAlexandros Lamprineas br label %bb 409893d3a61SAlexandros Lamprineas bb: 410893d3a61SAlexandros Lamprineas %3 = phi i32 [ %b, %case1 ], [ %2, %case2 ], [ %3, %bb ] 411893d3a61SAlexandros Lamprineas %4 = icmp eq i32 %3, 1 412893d3a61SAlexandros Lamprineas br i1 %4, label %bb, label %loop 413893d3a61SAlexandros Lamprineas default: 414893d3a61SAlexandros Lamprineas ret void 415893d3a61SAlexandros Lamprineas } 416893d3a61SAlexandros Lamprineas )"; 417893d3a61SAlexandros Lamprineas 418893d3a61SAlexandros Lamprineas Module &M = parseModule(ModuleString); 419893d3a61SAlexandros Lamprineas Function *F = M.getFunction("foo"); 420893d3a61SAlexandros Lamprineas FunctionSpecializer Specializer = getSpecializerFor(F); 421893d3a61SAlexandros Lamprineas InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 422893d3a61SAlexandros Lamprineas 423893d3a61SAlexandros Lamprineas Constant *One = ConstantInt::get(IntegerType::getInt32Ty(M.getContext()), 1); 424893d3a61SAlexandros Lamprineas 425893d3a61SAlexandros Lamprineas auto FuncIter = F->begin(); 426893d3a61SAlexandros Lamprineas BasicBlock &Loop = *++FuncIter; 427893d3a61SAlexandros Lamprineas BasicBlock &Case1 = *++FuncIter; 428893d3a61SAlexandros Lamprineas BasicBlock &Case2 = *++FuncIter; 429893d3a61SAlexandros Lamprineas BasicBlock &BB = *++FuncIter; 430893d3a61SAlexandros Lamprineas 431893d3a61SAlexandros Lamprineas Instruction &PhiLoop = Loop.front(); 4325bfefff1SAlexandros Lamprineas Instruction &Switch = Loop.back(); 433893d3a61SAlexandros Lamprineas Instruction &Add = Case1.front(); 434893d3a61SAlexandros Lamprineas Instruction &PhiCase2 = Case2.front(); 435893d3a61SAlexandros Lamprineas Instruction &BrBB = Case2.back(); 436893d3a61SAlexandros Lamprineas Instruction &PhiBB = BB.front(); 437893d3a61SAlexandros Lamprineas Instruction &Icmp = *++BB.begin(); 4385bfefff1SAlexandros Lamprineas Instruction &Branch = BB.back(); 439893d3a61SAlexandros Lamprineas 440c6931c25SHari Limaye Cost Test = Visitor.getCodeSizeSavingsForArg(F->getArg(0), One); 441c6931c25SHari Limaye EXPECT_TRUE(Test == 0); 442893d3a61SAlexandros Lamprineas 443c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(1), One); 444c6931c25SHari Limaye EXPECT_TRUE(Test == 0); 445c6931c25SHari Limaye 446c6931c25SHari Limaye Test = Visitor.getLatencySavingsForKnownConstants(); 447c6931c25SHari Limaye EXPECT_TRUE(Test == 0); 448893d3a61SAlexandros Lamprineas 4495bfefff1SAlexandros Lamprineas // switch + phi + br 450c6931c25SHari Limaye Cost Ref = getCodeSizeSavings(Switch) + 451c6931c25SHari Limaye getCodeSizeSavings(PhiCase2, /*HasLatencySavings=*/false) + 452c6931c25SHari Limaye getCodeSizeSavings(BrBB, /*HasLatencySavings=*/false); 453c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsForArg(F->getArg(2), One); 4545bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 455c6931c25SHari Limaye EXPECT_TRUE(Test > 0 && Test > 0); 456893d3a61SAlexandros Lamprineas 4575bfefff1SAlexandros Lamprineas // phi + phi + add + icmp + branch 458c6931c25SHari Limaye Ref = getCodeSizeSavings(PhiBB) + getCodeSizeSavings(PhiLoop) + 459c6931c25SHari Limaye getCodeSizeSavings(Add) + getCodeSizeSavings(Icmp) + 460c6931c25SHari Limaye getCodeSizeSavings(Branch); 461c6931c25SHari Limaye Test = Visitor.getCodeSizeSavingsFromPendingPHIs(); 4625bfefff1SAlexandros Lamprineas EXPECT_EQ(Test, Ref); 463c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 464c6931c25SHari Limaye 465c6931c25SHari Limaye // Latency. 466c6931c25SHari Limaye Ref = getLatencySavings(F); 467c6931c25SHari Limaye Test = Visitor.getLatencySavingsForKnownConstants(); 468c6931c25SHari Limaye EXPECT_EQ(Test, Ref); 469c6931c25SHari Limaye EXPECT_TRUE(Test > 0); 470893d3a61SAlexandros Lamprineas } 471893d3a61SAlexandros Lamprineas 472*5f30b1aaSHari Limaye TEST_F(FunctionSpecializationTest, BinOp) { 473*5f30b1aaSHari Limaye // Verify that we can handle binary operators even when only one operand is 474*5f30b1aaSHari Limaye // constant. 475*5f30b1aaSHari Limaye const char *ModuleString = R"( 476*5f30b1aaSHari Limaye define i32 @foo(i1 %a, i1 %b) { 477*5f30b1aaSHari Limaye %and1 = and i1 %a, %b 478*5f30b1aaSHari Limaye %and2 = and i1 %b, %and1 479*5f30b1aaSHari Limaye %sel = select i1 %and2, i32 1, i32 0 480*5f30b1aaSHari Limaye ret i32 %sel 481*5f30b1aaSHari Limaye } 482*5f30b1aaSHari Limaye )"; 483*5f30b1aaSHari Limaye 484*5f30b1aaSHari Limaye Module &M = parseModule(ModuleString); 485*5f30b1aaSHari Limaye Function *F = M.getFunction("foo"); 486*5f30b1aaSHari Limaye FunctionSpecializer Specializer = getSpecializerFor(F); 487*5f30b1aaSHari Limaye InstCostVisitor Visitor = Specializer.getInstCostVisitorFor(F); 488*5f30b1aaSHari Limaye 489*5f30b1aaSHari Limaye Constant *False = ConstantInt::getFalse(M.getContext()); 490*5f30b1aaSHari Limaye BasicBlock &BB = F->front(); 491*5f30b1aaSHari Limaye Instruction &And1 = BB.front(); 492*5f30b1aaSHari Limaye Instruction &And2 = *++BB.begin(); 493*5f30b1aaSHari Limaye Instruction &Select = *++BB.begin(); 494*5f30b1aaSHari Limaye 495*5f30b1aaSHari Limaye Cost RefCodeSize = getCodeSizeSavings(And1) + getCodeSizeSavings(And2) + 496*5f30b1aaSHari Limaye getCodeSizeSavings(Select); 497*5f30b1aaSHari Limaye Cost RefLatency = getLatencySavings(F); 498*5f30b1aaSHari Limaye 499*5f30b1aaSHari Limaye Cost TestCodeSize = Visitor.getCodeSizeSavingsForArg(F->getArg(0), False); 500*5f30b1aaSHari Limaye Cost TestLatency = Visitor.getLatencySavingsForKnownConstants(); 501*5f30b1aaSHari Limaye 502*5f30b1aaSHari Limaye EXPECT_EQ(TestCodeSize, RefCodeSize); 503*5f30b1aaSHari Limaye EXPECT_TRUE(TestCodeSize > 0); 504*5f30b1aaSHari Limaye EXPECT_EQ(TestLatency, RefLatency); 505*5f30b1aaSHari Limaye EXPECT_TRUE(TestLatency > 0); 506*5f30b1aaSHari Limaye } 507