xref: /llvm-project/llvm/unittests/Target/X86/TernlogTest.cpp (revision bb3f5e1fed7c6ba733b7f273e93f5d3930976185)
//===- TernlogTest.cpp - X86 pternlog constant folding tests --------------===//
2443825c5SNoah Goldstein //
3443825c5SNoah Goldstein // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4443825c5SNoah Goldstein // See https://llvm.org/LICENSE.txt for license information.
5443825c5SNoah Goldstein // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6443825c5SNoah Goldstein //
7443825c5SNoah Goldstein //===----------------------------------------------------------------------===//
8443825c5SNoah Goldstein 
9443825c5SNoah Goldstein #include "llvm/Analysis/TargetTransformInfo.h"
10443825c5SNoah Goldstein #include "llvm/AsmParser/Parser.h"
1174deadf1SNikita Popov #include "llvm/IR/Module.h"
12443825c5SNoah Goldstein #include "llvm/MC/TargetRegistry.h"
13443825c5SNoah Goldstein #include "llvm/Passes/PassBuilder.h"
14443825c5SNoah Goldstein #include "llvm/Support/TargetSelect.h"
15443825c5SNoah Goldstein #include "llvm/Target/TargetMachine.h"
16443825c5SNoah Goldstein #include "llvm/Transforms/InstCombine/InstCombine.h"
17443825c5SNoah Goldstein 
18443825c5SNoah Goldstein #include "gtest/gtest.h"
19443825c5SNoah Goldstein 
20443825c5SNoah Goldstein #include <random>
21443825c5SNoah Goldstein 
22443825c5SNoah Goldstein namespace llvm {
23*bb3f5e1fSMatin Raayai static std::unique_ptr<TargetMachine> initTM() {
24443825c5SNoah Goldstein   LLVMInitializeX86TargetInfo();
25443825c5SNoah Goldstein   LLVMInitializeX86Target();
26443825c5SNoah Goldstein   LLVMInitializeX86TargetMC();
27443825c5SNoah Goldstein 
28443825c5SNoah Goldstein   auto TT(Triple::normalize("x86_64--"));
29443825c5SNoah Goldstein   std::string Error;
30443825c5SNoah Goldstein   const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error);
31*bb3f5e1fSMatin Raayai   return std::unique_ptr<TargetMachine>(
32443825c5SNoah Goldstein       TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt,
33*bb3f5e1fSMatin Raayai                                      std::nullopt, CodeGenOptLevel::Default));
34443825c5SNoah Goldstein }
35443825c5SNoah Goldstein 
36443825c5SNoah Goldstein struct TernTester {
37443825c5SNoah Goldstein   unsigned NElem;
38443825c5SNoah Goldstein   unsigned ElemWidth;
39443825c5SNoah Goldstein   std::mt19937_64 Rng;
40443825c5SNoah Goldstein   unsigned ImmVal;
41443825c5SNoah Goldstein   SmallVector<uint64_t, 16> VecElems[3];
42443825c5SNoah Goldstein 
43443825c5SNoah Goldstein   void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; }
44443825c5SNoah Goldstein   void updateNElem(unsigned NewNElem) {
45443825c5SNoah Goldstein     NElem = NewNElem;
46443825c5SNoah Goldstein     for (unsigned I = 0; I < 3; ++I) {
47443825c5SNoah Goldstein       VecElems[I].resize(NElem);
48443825c5SNoah Goldstein     }
49443825c5SNoah Goldstein   }
50443825c5SNoah Goldstein   void updateElemWidth(unsigned NewElemWidth) {
51443825c5SNoah Goldstein     ElemWidth = NewElemWidth;
52443825c5SNoah Goldstein     assert(ElemWidth == 32 || ElemWidth == 64);
53443825c5SNoah Goldstein   }
54443825c5SNoah Goldstein 
55443825c5SNoah Goldstein   uint64_t getElemMask() const {
56443825c5SNoah Goldstein     return (~uint64_t(0)) >> ((ElemWidth - 0) % 64);
57443825c5SNoah Goldstein   }
58443825c5SNoah Goldstein 
59443825c5SNoah Goldstein   void RandomizeVecArgs() {
60443825c5SNoah Goldstein     uint64_t ElemMask = getElemMask();
61443825c5SNoah Goldstein     for (unsigned I = 0; I < 3; ++I) {
62443825c5SNoah Goldstein       for (unsigned J = 0; J < NElem; ++J) {
63443825c5SNoah Goldstein         VecElems[I][J] = Rng() & ElemMask;
64443825c5SNoah Goldstein       }
65443825c5SNoah Goldstein     }
66443825c5SNoah Goldstein   }
67443825c5SNoah Goldstein 
68443825c5SNoah Goldstein   std::pair<std::string, std::string> getScalarInfo() const {
69443825c5SNoah Goldstein     switch (ElemWidth) {
70443825c5SNoah Goldstein     case 32:
71443825c5SNoah Goldstein       return {"i32", "d"};
72443825c5SNoah Goldstein     case 64:
73443825c5SNoah Goldstein       return {"i64", "q"};
74443825c5SNoah Goldstein     default:
75443825c5SNoah Goldstein       llvm_unreachable("Invalid ElemWidth");
76443825c5SNoah Goldstein     }
77443825c5SNoah Goldstein   }
78443825c5SNoah Goldstein   std::string getScalarType() const { return getScalarInfo().first; }
79443825c5SNoah Goldstein   std::string getScalarExt() const { return getScalarInfo().second; }
80443825c5SNoah Goldstein   std::string getVecType() const {
81443825c5SNoah Goldstein     return "<" + Twine(NElem).str() + " x " + getScalarType() + ">";
82443825c5SNoah Goldstein   };
83443825c5SNoah Goldstein 
84443825c5SNoah Goldstein   std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); }
85443825c5SNoah Goldstein   std::string getFunctionName() const {
86443825c5SNoah Goldstein     return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth();
87443825c5SNoah Goldstein   }
88443825c5SNoah Goldstein   std::string getFunctionDecl() const {
89443825c5SNoah Goldstein     return "declare " + getVecType() + getFunctionName() + "(" + getVecType() +
90443825c5SNoah Goldstein            ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)";
91443825c5SNoah Goldstein   }
92443825c5SNoah Goldstein 
93443825c5SNoah Goldstein   std::string getVecN(unsigned N) const {
94443825c5SNoah Goldstein     assert(N < 3);
95443825c5SNoah Goldstein     std::string VecStr = getVecType() + " <";
96443825c5SNoah Goldstein     for (unsigned I = 0; I < VecElems[N].size(); ++I) {
97443825c5SNoah Goldstein       if (I != 0)
98443825c5SNoah Goldstein         VecStr += ", ";
99443825c5SNoah Goldstein       VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str();
100443825c5SNoah Goldstein     }
101443825c5SNoah Goldstein     return VecStr + ">";
102443825c5SNoah Goldstein   }
103443825c5SNoah Goldstein   std::string getFunctionCall() const {
104443825c5SNoah Goldstein     return "tail call " + getVecType() + " " + getFunctionName() + "(" +
105443825c5SNoah Goldstein            getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " +
106443825c5SNoah Goldstein            Twine(ImmVal).str() + ")";
107443825c5SNoah Goldstein   }
108443825c5SNoah Goldstein 
109443825c5SNoah Goldstein   std::string getTestText() const {
110443825c5SNoah Goldstein     return getFunctionDecl() + "\ndefine " + getVecType() +
111443825c5SNoah Goldstein            "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() +
112443825c5SNoah Goldstein            " %r\n}\n";
113443825c5SNoah Goldstein   }
114443825c5SNoah Goldstein 
115443825c5SNoah Goldstein   void checkResult(const Value *V) {
116443825c5SNoah Goldstein     auto GetValElem = [&](unsigned Idx) -> uint64_t {
117443825c5SNoah Goldstein       if (auto *CV = dyn_cast<ConstantDataVector>(V))
118443825c5SNoah Goldstein         return CV->getElementAsInteger(Idx);
119443825c5SNoah Goldstein 
120443825c5SNoah Goldstein       auto *C = dyn_cast<Constant>(V);
121443825c5SNoah Goldstein       assert(C);
122443825c5SNoah Goldstein       if (C->isNullValue())
123443825c5SNoah Goldstein         return 0;
124443825c5SNoah Goldstein       if (C->isAllOnesValue())
125443825c5SNoah Goldstein         return ((~uint64_t(0)) >> (ElemWidth % 64));
126443825c5SNoah Goldstein       if (C->isOneValue())
127443825c5SNoah Goldstein         return 1;
128443825c5SNoah Goldstein 
129443825c5SNoah Goldstein       llvm_unreachable("Unknown constant type");
130443825c5SNoah Goldstein     };
131443825c5SNoah Goldstein 
132443825c5SNoah Goldstein     auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t {
133443825c5SNoah Goldstein       unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1);
134443825c5SNoah Goldstein       return (ImmVal >> BitIdx) & 1;
135443825c5SNoah Goldstein     };
136443825c5SNoah Goldstein 
137443825c5SNoah Goldstein     for (unsigned I = 0; I < NElem; ++I) {
138443825c5SNoah Goldstein       uint64_t Expec = 0;
139443825c5SNoah Goldstein       uint64_t AEle = VecElems[0][I];
140443825c5SNoah Goldstein       uint64_t BEle = VecElems[1][I];
141443825c5SNoah Goldstein       uint64_t CEle = VecElems[2][I];
142443825c5SNoah Goldstein       for (unsigned J = 0; J < ElemWidth; ++J) {
143443825c5SNoah Goldstein         Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J;
144443825c5SNoah Goldstein       }
145443825c5SNoah Goldstein 
146443825c5SNoah Goldstein       ASSERT_EQ(Expec, GetValElem(I));
147443825c5SNoah Goldstein     }
148443825c5SNoah Goldstein   }
149443825c5SNoah Goldstein 
150443825c5SNoah Goldstein   void check(LLVMContext &Ctx, FunctionPassManager &FPM,
151443825c5SNoah Goldstein              FunctionAnalysisManager &FAM) {
152443825c5SNoah Goldstein     SMDiagnostic Error;
153443825c5SNoah Goldstein     std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx);
154443825c5SNoah Goldstein     ASSERT_TRUE(M);
155443825c5SNoah Goldstein     Function *F = M->getFunction("foo");
156443825c5SNoah Goldstein     ASSERT_TRUE(F);
157443825c5SNoah Goldstein     ASSERT_EQ(F->getInstructionCount(), 2u);
158cacbe71aSYingwei Zheng     FAM.clear();
159443825c5SNoah Goldstein     FPM.run(*F, FAM);
160443825c5SNoah Goldstein     ASSERT_EQ(F->getInstructionCount(), 1u);
161443825c5SNoah Goldstein     ASSERT_EQ(F->size(), 1u);
162443825c5SNoah Goldstein     const Instruction *I = F->begin()->getTerminator();
163443825c5SNoah Goldstein     ASSERT_TRUE(I);
164443825c5SNoah Goldstein     ASSERT_EQ(I->getNumOperands(), 1u);
165443825c5SNoah Goldstein     checkResult(I->getOperand(0));
166443825c5SNoah Goldstein   }
167443825c5SNoah Goldstein };
168443825c5SNoah Goldstein 
169443825c5SNoah Goldstein TEST(TernlogTest, TestConstantFolding) {
170443825c5SNoah Goldstein   LLVMContext Ctx;
171443825c5SNoah Goldstein   FunctionAnalysisManager FAM;
172443825c5SNoah Goldstein   FunctionPassManager FPM;
173443825c5SNoah Goldstein   PassBuilder PB;
174443825c5SNoah Goldstein   LoopAnalysisManager LAM;
175443825c5SNoah Goldstein   CGSCCAnalysisManager CGAM;
176443825c5SNoah Goldstein   ModuleAnalysisManager MAM;
177443825c5SNoah Goldstein   TargetIRAnalysis TIRA = TargetIRAnalysis(
178443825c5SNoah Goldstein       [&](const Function &F) { return initTM()->getTargetTransformInfo(F); });
179443825c5SNoah Goldstein 
180443825c5SNoah Goldstein   FAM.registerPass([&] { return TIRA; });
181443825c5SNoah Goldstein   PB.registerModuleAnalyses(MAM);
182443825c5SNoah Goldstein   PB.registerCGSCCAnalyses(CGAM);
183443825c5SNoah Goldstein   PB.registerFunctionAnalyses(FAM);
184443825c5SNoah Goldstein   PB.registerLoopAnalyses(LAM);
185443825c5SNoah Goldstein   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
186443825c5SNoah Goldstein 
187443825c5SNoah Goldstein   FPM.addPass(InstCombinePass());
188443825c5SNoah Goldstein   TernTester TT;
189443825c5SNoah Goldstein   for (unsigned NElem = 2; NElem < 16; NElem += NElem) {
190443825c5SNoah Goldstein     TT.updateNElem(NElem);
191443825c5SNoah Goldstein     for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) {
192443825c5SNoah Goldstein       if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128)
193443825c5SNoah Goldstein         continue;
194443825c5SNoah Goldstein       TT.updateElemWidth(ElemWidth);
195443825c5SNoah Goldstein       TT.RandomizeVecArgs();
196443825c5SNoah Goldstein       for (unsigned Imm = 0; Imm < 256; ++Imm) {
197443825c5SNoah Goldstein         TT.updateImm(Imm);
198443825c5SNoah Goldstein         TT.check(Ctx, FPM, FAM);
199443825c5SNoah Goldstein       }
200443825c5SNoah Goldstein     }
201443825c5SNoah Goldstein   }
202443825c5SNoah Goldstein }
203443825c5SNoah Goldstein } // namespace llvm
204