xref: /llvm-project/llvm/unittests/Target/AMDGPU/AMDGPUUnitTests.cpp (revision 53697a5dcdc4d83cbe0cb6d88e33c3f1bb3ea487)
1 //===--------- llvm/unittests/Target/AMDGPU/AMDGPUUnitTests.cpp -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "AMDGPUUnitTests.h"
10 #include "AMDGPUTargetMachine.h"
11 #include "GCNSubtarget.h"
12 #include "llvm/MC/TargetRegistry.h"
13 #include "llvm/Support/TargetSelect.h"
14 #include "llvm/TargetParser/TargetParser.h"
15 #include "gtest/gtest.h"
16 
17 #include "AMDGPUGenSubtargetInfo.inc"
18 
19 using namespace llvm;
20 
21 std::once_flag flag;
22 
InitializeAMDGPUTarget()23 void InitializeAMDGPUTarget() {
24   std::call_once(flag, []() {
25     LLVMInitializeAMDGPUTargetInfo();
26     LLVMInitializeAMDGPUTarget();
27     LLVMInitializeAMDGPUTargetMC();
28   });
29 }
30 
31 std::unique_ptr<const GCNTargetMachine>
createAMDGPUTargetMachine(std::string TStr,StringRef CPU,StringRef FS)32 llvm::createAMDGPUTargetMachine(std::string TStr, StringRef CPU, StringRef FS) {
33   InitializeAMDGPUTarget();
34 
35   std::string Error;
36   const Target *T = TargetRegistry::lookupTarget(TStr, Error);
37   if (!T)
38     return nullptr;
39 
40   TargetOptions Options;
41   return std::unique_ptr<GCNTargetMachine>(
42       static_cast<GCNTargetMachine *>(T->createTargetMachine(
43           TStr, CPU, FS, Options, std::nullopt, std::nullopt)));
44 }
45 
46 static cl::opt<bool> PrintCpuRegLimits(
47     "print-cpu-reg-limits", cl::NotHidden, cl::init(false),
48     cl::desc("force printing per AMDGPU CPU register limits"));
49 
/// Checks that the register range reported for occupancy \p Occ is consistent
/// and tight, and appends a one-line report for it to \p OS.
///
/// \p GetMinGPRs(Occ) and \p GetMaxGPRs(Occ) give the claimed register range
/// for \p Occ. \p GetOcc(NumGPRs) gives the occupancy obtained when using
/// NumGPRs registers.
/// - The range is valid when it is non-empty (Min < Max) and every register
///   count in [Min, Max] yields the same occupancy.
/// - When that occupancy lies in (\p MinOcc, \p MaxOcc], the bounds must also
///   be tight. With one register fewer than Min (if Min > 0) the occupancy
///   must be strictly higher. With one register more than Max it must be
///   strictly lower.
///
/// Report markers: '<' means a non-tight minimum, 'R' an invalid range and
/// " >" a non-tight maximum. Each bound is followed by its occupancy "(O<n>)".
///
/// \returns true if the range and both bounds are valid.
static bool checkMinMax(std::stringstream &OS, unsigned Occ, unsigned MinOcc,
                        unsigned MaxOcc,
                        std::function<unsigned(unsigned)> GetOcc,
                        std::function<unsigned(unsigned)> GetMinGPRs,
                        std::function<unsigned(unsigned)> GetMaxGPRs) {
  const unsigned MinGPRs = GetMinGPRs(Occ);
  const unsigned MaxGPRs = GetMaxGPRs(Occ);
  // Occupancy at the low end of the range. It is always initialized, so its
  // validity no longer hinges on short-circuit evaluation further down.
  const unsigned RealOcc = GetOcc(MinGPRs);

  // The range must be non-empty and map to a single occupancy throughout.
  bool RangeValid = MinGPRs < MaxGPRs;
  for (unsigned NumRegs = MinGPRs + 1; RangeValid && NumRegs <= MaxGPRs;
       ++NumRegs)
    RangeValid = GetOcc(NumRegs) == RealOcc;

  // Tightness of the bounds is only checked for occupancies in
  // (MinOcc, MaxOcc].
  bool MinValid = true;
  bool MaxValid = true;
  if (RangeValid && RealOcc > MinOcc && RealOcc <= MaxOcc) {
    if (MinGPRs > 0 && GetOcc(MinGPRs - 1) <= RealOcc)
      MinValid = false;
    if (GetOcc(MaxGPRs + 1) >= RealOcc)
      MaxValid = false;
  }

  std::stringstream MinStr;
  MinStr << (MinValid ? ' ' : '<') << ' ' << std::setw(3) << MinGPRs << " (O"
         << RealOcc << ") " << (RangeValid ? ' ' : 'R');

  OS << std::left << std::setw(15) << MinStr.str() << std::setw(3) << MaxGPRs
     << " (O" << GetOcc(MaxGPRs) << ')' << (MaxValid ? "" : " >");

  return MinValid && MaxValid && RangeValid;
}
89 
90 static const std::pair<StringRef, StringRef>
91   EmptyFS = {"", ""},
92   W32FS = {"+wavefrontsize32", "w32"},
93   W64FS = {"+wavefrontsize64", "w64"};
94 
95 using TestFuncTy =
96     function_ref<bool(std::stringstream &, unsigned, const GCNSubtarget &)>;
97 
testAndRecord(std::stringstream & Table,const GCNSubtarget & ST,TestFuncTy test)98 static bool testAndRecord(std::stringstream &Table, const GCNSubtarget &ST,
99                           TestFuncTy test) {
100   bool Success = true;
101   unsigned MaxOcc = ST.getMaxWavesPerEU();
102   for (unsigned Occ = MaxOcc; Occ > 0; --Occ) {
103     Table << std::right << std::setw(3) << Occ << "    ";
104     Success = test(Table, Occ, ST) && Success;
105     Table << '\n';
106   }
107   return Success;
108 }
109 
testGPRLimits(const char * RegName,bool TestW32W64,TestFuncTy test)110 static void testGPRLimits(const char *RegName, bool TestW32W64,
111                           TestFuncTy test) {
112   SmallVector<StringRef> CPUs;
113   AMDGPU::fillValidArchListAMDGCN(CPUs);
114 
115   std::map<std::string, SmallVector<std::string>> TablePerCPUs;
116   for (auto CPUName : CPUs) {
117     auto CanonCPUName =
118         AMDGPU::getArchNameAMDGCN(AMDGPU::parseArchAMDGCN(CPUName));
119 
120     auto *FS = &EmptyFS;
121     while (true) {
122       auto TM = createAMDGPUTargetMachine("amdgcn-amd-", CPUName, FS->first);
123       if (!TM)
124         break;
125 
126       GCNSubtarget ST(TM->getTargetTriple(), std::string(TM->getTargetCPU()),
127                       std::string(TM->getTargetFeatureString()), *TM);
128 
129       if (TestW32W64 &&
130           ST.getFeatureBits().test(AMDGPU::FeatureWavefrontSize32))
131         FS = &W32FS;
132 
133       std::stringstream Table;
134       bool Success = testAndRecord(Table, ST, test);
135       if (!Success || PrintCpuRegLimits)
136         TablePerCPUs[Table.str()].push_back((CanonCPUName + FS->second).str());
137 
138       if (FS != &W32FS)
139         break;
140 
141       FS = &W64FS;
142     }
143   }
144   std::stringstream OS;
145   for (auto &P : TablePerCPUs) {
146     for (auto &CPUName : P.second)
147       OS << ' ' << CPUName;
148     OS << ":\nOcc    Min" << RegName << "        Max" << RegName << '\n'
149        << P.first << '\n';
150   }
151   auto ErrStr = OS.str();
152   EXPECT_TRUE(ErrStr.empty()) << ErrStr;
153 }
154 
// For every occupancy level, checks that getMinNumVGPRs/getMaxNumVGPRs bound
// exactly the VGPR counts that getOccupancyWithNumVGPRs maps to that level.
TEST(AMDGPU, TestVGPRLimitsPerOccupancy) {
  auto CheckVGPRs = [](std::stringstream &OS, unsigned Occ,
                       const GCNSubtarget &ST) {
    // Lowest occupancy: the one reached when all addressable VGPRs are used.
    const unsigned AddressableVGPRs = ST.getAddressableNumVGPRs();
    const unsigned LowestOcc = ST.getOccupancyWithNumVGPRs(AddressableVGPRs);
    const unsigned HighestOcc = ST.getMaxWavesPerEU();

    auto OccForRegs = [&ST](unsigned NumGPRs) {
      return ST.getOccupancyWithNumVGPRs(NumGPRs);
    };
    auto MinRegsForOcc = [&ST](unsigned WavesPerEU) {
      return ST.getMinNumVGPRs(WavesPerEU);
    };
    auto MaxRegsForOcc = [&ST](unsigned WavesPerEU) {
      return ST.getMaxNumVGPRs(WavesPerEU);
    };

    return checkMinMax(OS, Occ, LowestOcc, HighestOcc, OccForRegs,
                       MinRegsForOcc, MaxRegsForOcc);
  };

  testGPRLimits("VGPR", /*TestW32W64=*/true, CheckVGPRs);
}
167