//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
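// For example (an illustrative OpenCL kernel, not a test case from the tree),
// with a required work-group size a get_local_size call can fold away:
//
//   __attribute__((reqd_work_group_size(64, 1, 1)))
//   kernel void f(global uint *out) {
//     out[0] = (uint)get_local_size(0); // Becomes the constant 64.
//   }
//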
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;

namespace {

// Field offsets in hsa_kernel_dispatch_packet_t.
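// For reference, a sketch of the packet prefix these offsets index into
// (field names as in hsa.h; the load sizes checked below assume this layout):
//
//   uint16_t header;            // Byte offset 0.
//   uint16_t setup;             // Byte offset 2.
//   uint16_t workgroup_size_x;  // Byte offset 4.
//   uint16_t workgroup_size_y;  // Byte offset 6.
//   uint16_t workgroup_size_z;  // Byte offset 8.
//   uint16_t reserved0;         // Byte offset 10.
//   uint32_t grid_size_x;       // Byte offset 12.
//   uint32_t grid_size_y;       // Byte offset 16.
//   uint32_t grid_size_z;       // Byte offset 20.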
enum DispatchPacketOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};

// Field offsets relative to the implicit kernel argument pointer.
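// Code object V5 passes these values as hidden kernel arguments. Per
// dimension, the intended meaning (paraphrased; see the V5 handling in
// processUse below) is:
//   block count - the number of full work-groups in the dispatch (a trailing
//                 partial work-group, if any, is not counted),
//   group size  - the size of a full work-group,
//   remainder   - the size of the trailing partial work-group, 0 if none.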
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};

class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

static bool processUse(CallInst *CI, bool IsV5OrAbove) {
  Function *F = CI->getParent()->getParent();

  auto MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
    F->getFnAttribute("uniform-work-group-size").getValueAsBool();

  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
    return false;

  Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
  Value *GroupSizes[3]  = {nullptr, nullptr, nullptr};
  Value *Remainders[3]  = {nullptr, nullptr, nullptr};
  Value *GridSizes[3]   = {nullptr, nullptr, nullptr};

  const DataLayout &DL = F->getParent()->getDataLayout();

  // We expect to see several GEP users, cast to the appropriate type and
  // loaded.
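  // Illustratively, with typed pointers the chain for a single field looks
  // like this (a sketch; value names are invented, and the GEP may be absent
  // for a field at offset 0):
  //
  //   %gep  = getelementptr inbounds i8, i8 addrspace(4)* %base, i64 4
  //   %bci  = bitcast i8 addrspace(4)* %gep to i16 addrspace(4)*
  //   %load = load i16, i16 addrspace(4)* %bci, align 4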
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    int64_t Offset = 0;
    BitCastInst *BCI = dyn_cast<BitCastInst>(U);
    if (!BCI) {
      if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
        continue;
      BCI = dyn_cast<BitCastInst>(*U->user_begin());
    }

    if (!BCI || !BCI->hasOneUse())
      continue;

    auto *Load = dyn_cast<LoadInst>(*BCI->user_begin());
    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // TODO: Handle merged loads.
    if (IsV5OrAbove) { // Base is ImplicitArgPtr.
      switch (Offset) {
      case HIDDEN_BLOCK_COUNT_X:
        if (LoadSize == 4)
          BlockCounts[0] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Y:
        if (LoadSize == 4)
          BlockCounts[1] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Z:
        if (LoadSize == 4)
          BlockCounts[2] = Load;
        break;
      case HIDDEN_GROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case HIDDEN_REMAINDER_X:
        if (LoadSize == 2)
          Remainders[0] = Load;
        break;
      case HIDDEN_REMAINDER_Y:
        if (LoadSize == 2)
          Remainders[1] = Load;
        break;
      case HIDDEN_REMAINDER_Z:
        if (LoadSize == 2)
          Remainders[2] = Load;
        break;
      default:
        break;
      }
    } else { // Base is DispatchPtr.
      switch (Offset) {
      case WORKGROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case WORKGROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case WORKGROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case GRID_SIZE_X:
        if (LoadSize == 4)
          GridSizes[0] = Load;
        break;
      case GRID_SIZE_Y:
        if (LoadSize == 4)
          GridSizes[1] = Load;
        break;
      case GRID_SIZE_Z:
        if (LoadSize == 4)
          GridSizes[2] = Load;
        break;
      default:
        break;
      }
    }
  }

  bool MadeChange = false;
  if (IsV5OrAbove && HasUniformWorkGroupSize) {
    // Under V5, __ockl_get_local_size returns the value computed by the
    // expression:
    //
    //   workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
    //
    // For functions with the attribute uniform-work-group-size=true, we can
    // evaluate workgroup_id < hidden_block_count as true, and thus
    // hidden_group_size is returned for __ockl_get_local_size.
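    // Illustratively, the comparison being folded looks like this (a sketch;
    // the value names are invented):
    //
    //   %id  = call i32 @llvm.amdgcn.workgroup.id.x()
    //   %cmp = icmp ult i32 %id, %block.count.x  ; replaced with true below
    //
    // The remainder loads feeding the corresponding selects are separately
    // replaced with 0.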
    for (int I = 0; I < 3; ++I) {
      Value *BlockCount = BlockCounts[I];
      if (!BlockCount)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *ICmp : BlockCount->users()) {
        ICmpInst::Predicate Pred;
        if (match(ICmp, m_ICmp(Pred, GroupIDIntrin, m_Specific(BlockCount)))) {
          if (Pred != ICmpInst::ICMP_ULT)
            continue;
          ICmp->replaceAllUsesWith(ConstantInt::getTrue(ICmp->getType()));
          MadeChange = true;
        }
      }
    }

    // All remainders should be 0 with uniform work group size.
    for (Value *Remainder : Remainders) {
      if (!Remainder)
        continue;
      Remainder->replaceAllUsesWith(
          Constant::getNullValue(Remainder->getType()));
      MadeChange = true;
    }
  } else if (HasUniformWorkGroupSize) { // Pre-V5.
    // Pattern match the code used to handle partial workgroup dispatches in
    // the library implementation of get_local_size, so the entire function
    // can be constant folded with a known group size.
    //
    // uint r = grid_size - group_id * group_size;
    // get_local_size = (r < group_size) ? r : group_size;
    //
    // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
    // grid_size is required to be a multiple of group_size. In this case:
    //
    // grid_size - (group_id * group_size) < group_size
    // ->
    // grid_size < group_size + (group_id * group_size)
    //
    // (grid_size / group_size) < 1 + group_id
    //
    // grid_size / group_size is at least 1, so we can conclude the select
    // condition is false (except for group_id == 0, where the select result
    // is the same).
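    // Illustratively, the IR being matched is roughly (a sketch for the x
    // dimension; %umin is the select that gets replaced):
    //
    //   %zext = zext i16 %group.size.x to i32
    //   %id   = call i32 @llvm.amdgcn.workgroup.id.x()
    //   %mul  = mul i32 %id, %zext
    //   %sub  = sub i32 %grid.size.x, %mul
    //   %cmp  = icmp ult i32 %sub, %zext
    //   %umin = select i1 %cmp, i32 %sub, i32 %zext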
    for (int I = 0; I < 3; ++I) {
      Value *GroupSize = GroupSizes[I];
      Value *GridSize = GridSizes[I];
      if (!GroupSize || !GridSize)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      for (User *U : GroupSize->users()) {
        auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
        if (!ZextGroupSize)
          continue;

        for (User *UMin : ZextGroupSize->users()) {
          if (match(UMin,
                    m_UMin(m_Sub(m_Specific(GridSize),
                                 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
                           m_Specific(ZextGroupSize)))) {
            if (HasReqdWorkGroupSize) {
              ConstantInt *KnownSize =
                  mdconst::extract<ConstantInt>(MD->getOperand(I));
              UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
                  KnownSize, UMin->getType(), false));
            } else {
              UMin->replaceAllUsesWith(ZextGroupSize);
            }

            MadeChange = true;
          }
        }
      }
    }
  }

  // If reqd_work_group_size is set, we can replace work group size with it.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

  for (int I = 0; I < 3; ++I) {
    Value *GroupSize = GroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
        ConstantExpr::getIntegerCast(KnownSize, GroupSize->getType(), false));
    MadeChange = true;
  }

  return MadeChange;
}

// TODO: Move makeLIDRangeMetadata usage into here. We seem unable to get
// TargetPassConfig for the subtarget.
bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
  bool MadeChange = false;
  Function *BasePtr = nullptr;
  bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
  if (IsV5OrAbove) {
    StringRef ImplicitArgPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_implicitarg_ptr);
    BasePtr = M.getFunction(ImplicitArgPtrName);
  } else { // Pre-V5.
    StringRef DispatchPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
    BasePtr = M.getFunction(DispatchPtrName);
  }

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return false;

  SmallPtrSet<Instruction *, 4> HandledUses;
  for (auto *U : BasePtr->users()) {
    CallInst *CI = cast<CallInst>(U);
    if (HandledUses.insert(CI).second) {
      if (processUse(CI, IsV5OrAbove))
        MadeChange = true;
    }
  }

  return MadeChange;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

char AMDGPULowerKernelAttributes::ID = 0;

ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
  return new AMDGPULowerKernelAttributes();
}

PreservedAnalyses
AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
  Function *BasePtr = nullptr;
  bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
  if (IsV5OrAbove) {
    StringRef ImplicitArgPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_implicitarg_ptr);
    BasePtr = F.getParent()->getFunction(ImplicitArgPtrName);
  } else { // Pre-V5.
    StringRef DispatchPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
    BasePtr = F.getParent()->getFunction(DispatchPtrName);
  }

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return PreservedAnalyses::all();

  for (Instruction &I : instructions(F)) {
    if (CallInst *CI = dyn_cast<CallInst>(&I)) {
      if (CI->getCalledFunction() == BasePtr)
        processUse(CI, IsV5OrAbove);
    }
  }

  return PreservedAnalyses::all();
}