1 //===- LICMTest.cpp - LICM unit tests -------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "llvm/Analysis/TargetTransformInfo.h" 10 #include "llvm/AsmParser/Parser.h" 11 #include "llvm/IR/Module.h" 12 #include "llvm/MC/TargetRegistry.h" 13 #include "llvm/Passes/PassBuilder.h" 14 #include "llvm/Support/TargetSelect.h" 15 #include "llvm/Target/TargetMachine.h" 16 #include "llvm/Transforms/InstCombine/InstCombine.h" 17 18 #include "gtest/gtest.h" 19 20 #include <random> 21 22 namespace llvm { 23 static std::unique_ptr<TargetMachine> initTM() { 24 LLVMInitializeX86TargetInfo(); 25 LLVMInitializeX86Target(); 26 LLVMInitializeX86TargetMC(); 27 28 auto TT(Triple::normalize("x86_64--")); 29 std::string Error; 30 const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error); 31 return std::unique_ptr<TargetMachine>( 32 TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt, 33 std::nullopt, CodeGenOptLevel::Default)); 34 } 35 36 struct TernTester { 37 unsigned NElem; 38 unsigned ElemWidth; 39 std::mt19937_64 Rng; 40 unsigned ImmVal; 41 SmallVector<uint64_t, 16> VecElems[3]; 42 43 void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; } 44 void updateNElem(unsigned NewNElem) { 45 NElem = NewNElem; 46 for (unsigned I = 0; I < 3; ++I) { 47 VecElems[I].resize(NElem); 48 } 49 } 50 void updateElemWidth(unsigned NewElemWidth) { 51 ElemWidth = NewElemWidth; 52 assert(ElemWidth == 32 || ElemWidth == 64); 53 } 54 55 uint64_t getElemMask() const { 56 return (~uint64_t(0)) >> ((ElemWidth - 0) % 64); 57 } 58 59 void RandomizeVecArgs() { 60 uint64_t ElemMask = getElemMask(); 61 for (unsigned I = 0; I < 3; ++I) { 62 for (unsigned J = 0; J < NElem; ++J) { 63 VecElems[I][J] = Rng() & ElemMask; 64 } 65 } 66 } 67 68 std::pair<std::string, std::string> getScalarInfo() const { 69 switch (ElemWidth) { 70 case 32: 71 return {"i32", "d"}; 72 case 64: 73 return {"i64", "q"}; 74 default: 75 llvm_unreachable("Invalid ElemWidth"); 76 } 77 } 78 std::string getScalarType() const { return getScalarInfo().first; } 79 std::string getScalarExt() const { return getScalarInfo().second; } 80 std::string getVecType() const { 81 return "<" + Twine(NElem).str() + " x " + getScalarType() + ">"; 82 }; 83 84 std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); } 85 std::string getFunctionName() const { 86 return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth(); 87 } 88 std::string getFunctionDecl() const { 89 return "declare " + getVecType() + getFunctionName() + "(" + getVecType() + 90 ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)"; 91 } 92 93 std::string getVecN(unsigned N) const { 94 assert(N < 3); 95 std::string VecStr = getVecType() + " <"; 96 for (unsigned I = 0; I < VecElems[N].size(); ++I) { 97 if (I != 0) 98 VecStr += ", "; 99 VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str(); 100 } 101 return VecStr + ">"; 102 } 103 std::string getFunctionCall() const { 104 return "tail call " + getVecType() + " " + getFunctionName() + "(" + 105 getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " + 106 Twine(ImmVal).str() + ")"; 107 } 108 109 std::string getTestText() const { 110 return getFunctionDecl() + "\ndefine " + getVecType() + 111 "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() + 112 " %r\n}\n"; 113 } 114 115 void checkResult(const Value *V) { 116 auto GetValElem = [&](unsigned Idx) -> uint64_t { 117 if (auto *CV = dyn_cast<ConstantDataVector>(V)) 118 return CV->getElementAsInteger(Idx); 119 120 auto *C = dyn_cast<Constant>(V); 121 assert(C); 122 if (C->isNullValue()) 123 return 0; 124 if (C->isAllOnesValue()) 125 return ((~uint64_t(0)) >> (ElemWidth % 64)); 126 if (C->isOneValue()) 127 return 1; 128 129 llvm_unreachable("Unknown constant type"); 130 }; 131 132 auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t { 133 unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1); 134 return (ImmVal >> BitIdx) & 1; 135 }; 136 137 for (unsigned I = 0; I < NElem; ++I) { 138 uint64_t Expec = 0; 139 uint64_t AEle = VecElems[0][I]; 140 uint64_t BEle = VecElems[1][I]; 141 uint64_t CEle = VecElems[2][I]; 142 for (unsigned J = 0; J < ElemWidth; ++J) { 143 Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J; 144 } 145 146 ASSERT_EQ(Expec, GetValElem(I)); 147 } 148 } 149 150 void check(LLVMContext &Ctx, FunctionPassManager &FPM, 151 FunctionAnalysisManager &FAM) { 152 SMDiagnostic Error; 153 std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx); 154 ASSERT_TRUE(M); 155 Function *F = M->getFunction("foo"); 156 ASSERT_TRUE(F); 157 ASSERT_EQ(F->getInstructionCount(), 2u); 158 FAM.clear(); 159 FPM.run(*F, FAM); 160 ASSERT_EQ(F->getInstructionCount(), 1u); 161 ASSERT_EQ(F->size(), 1u); 162 const Instruction *I = F->begin()->getTerminator(); 163 ASSERT_TRUE(I); 164 ASSERT_EQ(I->getNumOperands(), 1u); 165 checkResult(I->getOperand(0)); 166 } 167 }; 168 169 TEST(TernlogTest, TestConstantFolding) { 170 LLVMContext Ctx; 171 FunctionAnalysisManager FAM; 172 FunctionPassManager FPM; 173 PassBuilder PB; 174 LoopAnalysisManager LAM; 175 CGSCCAnalysisManager CGAM; 176 ModuleAnalysisManager MAM; 177 TargetIRAnalysis TIRA = TargetIRAnalysis( 178 [&](const Function &F) { return initTM()->getTargetTransformInfo(F); }); 179 180 FAM.registerPass([&] { return TIRA; }); 181 PB.registerModuleAnalyses(MAM); 182 PB.registerCGSCCAnalyses(CGAM); 183 PB.registerFunctionAnalyses(FAM); 184 PB.registerLoopAnalyses(LAM); 185 PB.crossRegisterProxies(LAM, FAM, CGAM, MAM); 186 187 FPM.addPass(InstCombinePass()); 188 TernTester TT; 189 for (unsigned NElem = 2; NElem < 16; NElem += NElem) { 190 TT.updateNElem(NElem); 191 for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) { 192 if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128) 193 continue; 194 TT.updateElemWidth(ElemWidth); 195 TT.RandomizeVecArgs(); 196 for (unsigned Imm = 0; Imm < 256; ++Imm) { 197 TT.updateImm(Imm); 198 TT.check(Ctx, FPM, FAM); 199 } 200 } 201 } 202 } 203 } // namespace llvm 204