xref: /llvm-project/llvm/unittests/Target/X86/TernlogTest.cpp (revision bb3f5e1fed7c6ba733b7f273e93f5d3930976185)
//===- TernlogTest.cpp - X86 pternlog constant-folding tests --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Module.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"

#include "gtest/gtest.h"

#include <optional>
#include <random>
21 
22 namespace llvm {
23 static std::unique_ptr<TargetMachine> initTM() {
24   LLVMInitializeX86TargetInfo();
25   LLVMInitializeX86Target();
26   LLVMInitializeX86TargetMC();
27 
28   auto TT(Triple::normalize("x86_64--"));
29   std::string Error;
30   const Target *TheTarget = TargetRegistry::lookupTarget(TT, Error);
31   return std::unique_ptr<TargetMachine>(
32       TheTarget->createTargetMachine(TT, "", "", TargetOptions(), std::nullopt,
33                                      std::nullopt, CodeGenOptLevel::Default));
34 }
35 
36 struct TernTester {
37   unsigned NElem;
38   unsigned ElemWidth;
39   std::mt19937_64 Rng;
40   unsigned ImmVal;
41   SmallVector<uint64_t, 16> VecElems[3];
42 
43   void updateImm(uint8_t NewImmVal) { ImmVal = NewImmVal; }
44   void updateNElem(unsigned NewNElem) {
45     NElem = NewNElem;
46     for (unsigned I = 0; I < 3; ++I) {
47       VecElems[I].resize(NElem);
48     }
49   }
50   void updateElemWidth(unsigned NewElemWidth) {
51     ElemWidth = NewElemWidth;
52     assert(ElemWidth == 32 || ElemWidth == 64);
53   }
54 
55   uint64_t getElemMask() const {
56     return (~uint64_t(0)) >> ((ElemWidth - 0) % 64);
57   }
58 
59   void RandomizeVecArgs() {
60     uint64_t ElemMask = getElemMask();
61     for (unsigned I = 0; I < 3; ++I) {
62       for (unsigned J = 0; J < NElem; ++J) {
63         VecElems[I][J] = Rng() & ElemMask;
64       }
65     }
66   }
67 
68   std::pair<std::string, std::string> getScalarInfo() const {
69     switch (ElemWidth) {
70     case 32:
71       return {"i32", "d"};
72     case 64:
73       return {"i64", "q"};
74     default:
75       llvm_unreachable("Invalid ElemWidth");
76     }
77   }
78   std::string getScalarType() const { return getScalarInfo().first; }
79   std::string getScalarExt() const { return getScalarInfo().second; }
80   std::string getVecType() const {
81     return "<" + Twine(NElem).str() + " x " + getScalarType() + ">";
82   };
83 
84   std::string getVecWidth() const { return Twine(NElem * ElemWidth).str(); }
85   std::string getFunctionName() const {
86     return "@llvm.x86.avx512.pternlog." + getScalarExt() + "." + getVecWidth();
87   }
88   std::string getFunctionDecl() const {
89     return "declare " + getVecType() + getFunctionName() + "(" + getVecType() +
90            ", " + getVecType() + ", " + getVecType() + ", " + "i32 immarg)";
91   }
92 
93   std::string getVecN(unsigned N) const {
94     assert(N < 3);
95     std::string VecStr = getVecType() + " <";
96     for (unsigned I = 0; I < VecElems[N].size(); ++I) {
97       if (I != 0)
98         VecStr += ", ";
99       VecStr += getScalarType() + " " + Twine(VecElems[N][I]).str();
100     }
101     return VecStr + ">";
102   }
103   std::string getFunctionCall() const {
104     return "tail call " + getVecType() + " " + getFunctionName() + "(" +
105            getVecN(0) + ", " + getVecN(1) + ", " + getVecN(2) + ", " + "i32 " +
106            Twine(ImmVal).str() + ")";
107   }
108 
109   std::string getTestText() const {
110     return getFunctionDecl() + "\ndefine " + getVecType() +
111            "@foo() {\n%r = " + getFunctionCall() + "\nret " + getVecType() +
112            " %r\n}\n";
113   }
114 
115   void checkResult(const Value *V) {
116     auto GetValElem = [&](unsigned Idx) -> uint64_t {
117       if (auto *CV = dyn_cast<ConstantDataVector>(V))
118         return CV->getElementAsInteger(Idx);
119 
120       auto *C = dyn_cast<Constant>(V);
121       assert(C);
122       if (C->isNullValue())
123         return 0;
124       if (C->isAllOnesValue())
125         return ((~uint64_t(0)) >> (ElemWidth % 64));
126       if (C->isOneValue())
127         return 1;
128 
129       llvm_unreachable("Unknown constant type");
130     };
131 
132     auto ComputeBit = [&](uint64_t A, uint64_t B, uint64_t C) -> uint64_t {
133       unsigned BitIdx = ((A & 1) << 2) | ((B & 1) << 1) | (C & 1);
134       return (ImmVal >> BitIdx) & 1;
135     };
136 
137     for (unsigned I = 0; I < NElem; ++I) {
138       uint64_t Expec = 0;
139       uint64_t AEle = VecElems[0][I];
140       uint64_t BEle = VecElems[1][I];
141       uint64_t CEle = VecElems[2][I];
142       for (unsigned J = 0; J < ElemWidth; ++J) {
143         Expec |= ComputeBit(AEle >> J, BEle >> J, CEle >> J) << J;
144       }
145 
146       ASSERT_EQ(Expec, GetValElem(I));
147     }
148   }
149 
150   void check(LLVMContext &Ctx, FunctionPassManager &FPM,
151              FunctionAnalysisManager &FAM) {
152     SMDiagnostic Error;
153     std::unique_ptr<Module> M = parseAssemblyString(getTestText(), Error, Ctx);
154     ASSERT_TRUE(M);
155     Function *F = M->getFunction("foo");
156     ASSERT_TRUE(F);
157     ASSERT_EQ(F->getInstructionCount(), 2u);
158     FAM.clear();
159     FPM.run(*F, FAM);
160     ASSERT_EQ(F->getInstructionCount(), 1u);
161     ASSERT_EQ(F->size(), 1u);
162     const Instruction *I = F->begin()->getTerminator();
163     ASSERT_TRUE(I);
164     ASSERT_EQ(I->getNumOperands(), 1u);
165     checkResult(I->getOperand(0));
166   }
167 };
168 
169 TEST(TernlogTest, TestConstantFolding) {
170   LLVMContext Ctx;
171   FunctionAnalysisManager FAM;
172   FunctionPassManager FPM;
173   PassBuilder PB;
174   LoopAnalysisManager LAM;
175   CGSCCAnalysisManager CGAM;
176   ModuleAnalysisManager MAM;
177   TargetIRAnalysis TIRA = TargetIRAnalysis(
178       [&](const Function &F) { return initTM()->getTargetTransformInfo(F); });
179 
180   FAM.registerPass([&] { return TIRA; });
181   PB.registerModuleAnalyses(MAM);
182   PB.registerCGSCCAnalyses(CGAM);
183   PB.registerFunctionAnalyses(FAM);
184   PB.registerLoopAnalyses(LAM);
185   PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);
186 
187   FPM.addPass(InstCombinePass());
188   TernTester TT;
189   for (unsigned NElem = 2; NElem < 16; NElem += NElem) {
190     TT.updateNElem(NElem);
191     for (unsigned ElemWidth = 32; ElemWidth <= 64; ElemWidth += ElemWidth) {
192       if (ElemWidth * NElem > 512 || ElemWidth * NElem < 128)
193         continue;
194       TT.updateElemWidth(ElemWidth);
195       TT.RandomizeVecArgs();
196       for (unsigned Imm = 0; Imm < 256; ++Imm) {
197         TT.updateImm(Imm);
198         TT.check(Ctx, FPM, FAM);
199       }
200     }
201   }
202 }
203 } // namespace llvm
204