1443825c5SNoah Goldstein //===- LICMTest.cpp - LICM unit tests -------------------------------------===// 2443825c5SNoah Goldstein // 3443825c5SNoah Goldstein // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4443825c5SNoah Goldstein // See https://llvm.org/LICENSE.txt for license information. 5443825c5SNoah Goldstein // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6443825c5SNoah Goldstein // 7443825c5SNoah Goldstein //===----------------------------------------------------------------------===// 8443825c5SNoah Goldstein 9443825c5SNoah Goldstein #include "llvm/Analysis/TargetTransformInfo.h" 10443825c5SNoah Goldstein #include "llvm/AsmParser/Parser.h" 1174deadf1SNikita Popov #include "llvm/IR/Module.h" 12443825c5SNoah Goldstein #include "llvm/MC/TargetRegistry.h" 13443825c5SNoah Goldstein #include "llvm/Passes/PassBuilder.h" 14443825c5SNoah Goldstein #include "llvm/Support/TargetSelect.h" 15443825c5SNoah Goldstein #include "llvm/Target/TargetMachine.h" 16443825c5SNoah Goldstein #include "llvm/Transforms/InstCombine/InstCombine.h" 17443825c5SNoah Goldstein 18443825c5SNoah Goldstein #include "gtest/gtest.h" 19443825c5SNoah Goldstein 20443825c5SNoah Goldstein #include <random> 21443825c5SNoah Goldstein 22443825c5SNoah Goldstein namespace llvm { 23*bb3f5e1fSMatin Raayai static std::unique_ptr<TargetMachine> initTM() { 24443825c5SNoah Goldstein LLVMInitializeX86TargetInfo(); 25443825c5SNoah Goldstein LLVMInitializeX86Target(); 26443825c5SNoah Goldstein LLVMInitializeX86TargetMC(); 27443825c5SNoah Goldstein 28443825c5SNoah Goldstein auto TT(Triple::normalize("x86_64--")); 29443825c5SNoah Goldstein std::string Error; 30443825c5SNoah Goldstein const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error); 31*bb3f5e1fSMatin Raayai return std::unique_ptr<TargetMachine>( 32443825c5SNoah Goldstein TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt, 33*bb3f5e1fSMatin Raayai std::nullopt, CodeGenOptLevel::Default)); 34443825c5SNoah Goldstein } 35443825c5SNoah Goldstein 36443825c5SNoah Goldstein struct TernTester { 37443825c5SNoah Goldstein unsigned NElem; 38443825c5SNoah Goldstein unsigned ElemWidth; 39443825c5SNoah Goldstein std::mt19937_64 Rng; 40443825c5SNoah Goldstein unsigned ImmVal; 41443825c5SNoah Goldstein SmallVector<uint64_t, 16> VecElems[3]; 42443825c5SNoah Goldstein 43443825c5SNoah Goldstein void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; } 44443825c5SNoah Goldstein void updateNElem(unsigned NewNElem) { 45443825c5SNoah Goldstein NElem = NewNElem; 46443825c5SNoah Goldstein for (unsigned I = 0; I < 3; ++I) { 47443825c5SNoah Goldstein VecElems[I].resize(NElem); 48443825c5SNoah Goldstein } 49443825c5SNoah Goldstein } 50443825c5SNoah Goldstein void updateElemWidth(unsigned NewElemWidth) { 51443825c5SNoah Goldstein ElemWidth = NewElemWidth; 52443825c5SNoah Goldstein assert(ElemWidth == 32 || ElemWidth == 64); 53443825c5SNoah Goldstein } 54443825c5SNoah Goldstein 55443825c5SNoah Goldstein uint64_t getElemMask() const { 56443825c5SNoah Goldstein return (~uint64_t(0)) >> ((ElemWidth - 0) % 64); 57443825c5SNoah Goldstein } 58443825c5SNoah Goldstein 59443825c5SNoah Goldstein void RandomizeVecArgs() { 60443825c5SNoah Goldstein uint64_t ElemMask = getElemMask(); 61443825c5SNoah Goldstein for (unsigned I = 0; I < 3; ++I) { 62443825c5SNoah Goldstein for (unsigned J = 0; J < NElem; ++J) { 63443825c5SNoah Goldstein VecElems[I][J] = Rng() & ElemMask; 64443825c5SNoah Goldstein } 65443825c5SNoah Goldstein } 66443825c5SNoah Goldstein } 67443825c5SNoah Goldstein 68443825c5SNoah Goldstein std::pair<std::string, std::string> getScalarInfo() const { 69443825c5SNoah Goldstein switch (ElemWidth) { 70443825c5SNoah Goldstein case 32: 71443825c5SNoah Goldstein return {"i32", "d"}; 72443825c5SNoah Goldstein case 64: 73443825c5SNoah Goldstein return {"i64", "q"}; 74443825c5SNoah Goldstein default: 75443825c5SNoah Goldstein llvm_unreachable("Invalid ElemWidth"); 76443825c5SNoah Goldstein } 77443825c5SNoah Goldstein } 78443825c5SNoah Goldstein std::string getScalarType() const { return getScalarInfo().first; } 79443825c5SNoah Goldstein std::string getScalarExt() const { return getScalarInfo().second; } 80443825c5SNoah Goldstein std::string getVecType() const { 81443825c5SNoah Goldstein return "<" + Twine(NElem).str() + " x " + getScalarType() + ">"; 82443825c5SNoah Goldstein }; 83443825c5SNoah Goldstein 84443825c5SNoah Goldstein std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); } 85443825c5SNoah Goldstein std::string getFunctionName() const { 86443825c5SNoah Goldstein return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth(); 87443825c5SNoah Goldstein } 88443825c5SNoah Goldstein std::string getFunctionDecl() const { 89443825c5SNoah Goldstein return "declare " + getVecType() + getFunctionName() + "(" + getVecType() + 90443825c5SNoah Goldstein ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)"; 91443825c5SNoah Goldstein } 92443825c5SNoah Goldstein 93443825c5SNoah Goldstein std::string getVecN(unsigned N) const { 94443825c5SNoah Goldstein assert(N < 3); 95443825c5SNoah Goldstein std::string VecStr = getVecType() + " <"; 96443825c5SNoah Goldstein for (unsigned I = 0; I < VecElems[N].size(); ++I) { 97443825c5SNoah Goldstein if (I != 0) 98443825c5SNoah Goldstein VecStr += ", "; 99443825c5SNoah Goldstein VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str(); 100443825c5SNoah Goldstein } 101443825c5SNoah Goldstein return VecStr + ">"; 102443825c5SNoah Goldstein } 103443825c5SNoah Goldstein std::string getFunctionCall() const { 104443825c5SNoah Goldstein return "tail call " + getVecType() + " " + getFunctionName() + "(" + 105443825c5SNoah Goldstein getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " + 106443825c5SNoah Goldstein Twine(ImmVal).str() + ")"; 107443825c5SNoah Goldstein } 108443825c5SNoah Goldstein 109443825c5SNoah Goldstein std::string getTestText() const { 110443825c5SNoah Goldstein return getFunctionDecl() + "\ndefine " + getVecType() + 111443825c5SNoah Goldstein "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() + 112443825c5SNoah Goldstein " %r\n}\n"; 113443825c5SNoah Goldstein } 114443825c5SNoah Goldstein 115443825c5SNoah Goldstein void checkResult(const Value *V) { 116443825c5SNoah Goldstein auto GetValElem = [&](unsigned Idx) -> uint64_t { 117443825c5SNoah Goldstein if (auto *CV = dyn_cast<ConstantDataVector>(V)) 118443825c5SNoah Goldstein return CV->getElementAsInteger(Idx); 119443825c5SNoah Goldstein 120443825c5SNoah Goldstein auto *C = dyn_cast<Constant>(V); 121443825c5SNoah Goldstein assert(C); 122443825c5SNoah Goldstein if (C->isNullValue()) 123443825c5SNoah Goldstein return 0; 124443825c5SNoah Goldstein if (C->isAllOnesValue()) 125443825c5SNoah Goldstein return ((~uint64_t(0)) >> (ElemWidth % 64)); 126443825c5SNoah Goldstein if (C->isOneValue()) 127443825c5SNoah Goldstein return 1; 128443825c5SNoah Goldstein 129443825c5SNoah Goldstein llvm_unreachable("Unknown constant type"); 130443825c5SNoah Goldstein }; 131443825c5SNoah Goldstein 132443825c5SNoah Goldstein auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t { 133443825c5SNoah Goldstein unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1); 134443825c5SNoah Goldstein return (ImmVal >> BitIdx) & 1; 135443825c5SNoah Goldstein }; 136443825c5SNoah Goldstein 137443825c5SNoah Goldstein for (unsigned I = 0; I < NElem; ++I) { 138443825c5SNoah Goldstein uint64_t Expec = 0; 139443825c5SNoah Goldstein uint64_t AEle = VecElems[0][I]; 140443825c5SNoah Goldstein uint64_t BEle = VecElems[1][I]; 141443825c5SNoah Goldstein uint64_t CEle = VecElems[2][I]; 142443825c5SNoah Goldstein for (unsigned J = 0; J < ElemWidth; ++J) { 143443825c5SNoah Goldstein Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J; 144443825c5SNoah Goldstein } 145443825c5SNoah Goldstein 146443825c5SNoah Goldstein ASSERT_EQ(Expec, GetValElem(I)); 147443825c5SNoah Goldstein } 148443825c5SNoah Goldstein } 149443825c5SNoah Goldstein 150443825c5SNoah Goldstein void check(LLVMContext &Ctx, FunctionPassManager &FPM, 151443825c5SNoah Goldstein FunctionAnalysisManager &FAM) { 152443825c5SNoah Goldstein SMDiagnostic Error; 153443825c5SNoah Goldstein std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx); 154443825c5SNoah Goldstein ASSERT_TRUE(M); 155443825c5SNoah Goldstein Function *F = M->getFunction("foo"); 156443825c5SNoah Goldstein ASSERT_TRUE(F); 157443825c5SNoah Goldstein ASSERT_EQ(F->getInstructionCount(), 2u); 158cacbe71aSYingwei Zheng FAM.clear(); 159443825c5SNoah Goldstein FPM.run(*F, FAM); 160443825c5SNoah Goldstein ASSERT_EQ(F->getInstructionCount(), 1u); 161443825c5SNoah Goldstein ASSERT_EQ(F->size(), 1u); 162443825c5SNoah Goldstein const Instruction *I = F->begin()->getTerminator(); 163443825c5SNoah Goldstein ASSERT_TRUE(I); 164443825c5SNoah Goldstein ASSERT_EQ(I->getNumOperands(), 1u); 165443825c5SNoah Goldstein checkResult(I->getOperand(0)); 166443825c5SNoah Goldstein } 167443825c5SNoah Goldstein }; 168443825c5SNoah Goldstein 169443825c5SNoah Goldstein TEST(TernlogTest, TestConstantFolding) { 170443825c5SNoah Goldstein LLVMContext Ctx; 171443825c5SNoah Goldstein FunctionAnalysisManager FAM; 172443825c5SNoah Goldstein FunctionPassManager FPM; 173443825c5SNoah Goldstein PassBuilder PB; 174443825c5SNoah Goldstein LoopAnalysisManager LAM; 175443825c5SNoah Goldstein CGSCCAnalysisManager CGAM; 176443825c5SNoah Goldstein ModuleAnalysisManager MAM; 177443825c5SNoah Goldstein TargetIRAnalysis TIRA = TargetIRAnalysis( 178443825c5SNoah Goldstein [&](const Function &F) { return initTM()->getTargetTransformInfo(F); }); 179443825c5SNoah Goldstein 180443825c5SNoah Goldstein FAM.registerPass([&] { return TIRA; }); 181443825c5SNoah Goldstein PB.registerModuleAnalyses(MAM); 182443825c5SNoah Goldstein PB.registerCGSCCAnalyses(CGAM); 183443825c5SNoah Goldstein PB.registerFunctionAnalyses(FAM); 184443825c5SNoah Goldstein PB.registerLoopAnalyses(LAM); 185443825c5SNoah Goldstein PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 186443825c5SNoah Goldstein 187443825c5SNoah Goldstein FPM.addPass(InstCombinePass()); 188443825c5SNoah Goldstein TernTester TT; 189443825c5SNoah Goldstein for (unsigned NElem = 2; NElem < 16; NElem += NElem) { 190443825c5SNoah Goldstein TT.updateNElem(NElem); 191443825c5SNoah Goldstein for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) { 192443825c5SNoah Goldstein if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128) 193443825c5SNoah Goldstein continue; 194443825c5SNoah Goldstein TT.updateElemWidth(ElemWidth); 195443825c5SNoah Goldstein TT.RandomizeVecArgs(); 196443825c5SNoah Goldstein for (unsigned Imm = 0; Imm < 256; ++Imm) { 197443825c5SNoah Goldstein TT.updateImm(Imm); 198443825c5SNoah Goldstein TT.check(Ctx, FPM, FAM); 199443825c5SNoah Goldstein } 200443825c5SNoah Goldstein } 201443825c5SNoah Goldstein } 202443825c5SNoah Goldstein } 203443825c5SNoah Goldstein } // namespace llvm 204