xref: /llvm-project/llvm/lib/Target/AMDGPU/AMDGPULowerKernelAttributes.cpp (revision 62e9f40949ddc52e9660b25ab146bd5d9b39ad88)
1 //===-- AMDGPULowerKernelAttributes.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "AMDGPU.h"
16 #include "Utils/AMDGPUBaseInfo.h"
17 #include "llvm/Analysis/ConstantFolding.h"
18 #include "llvm/Analysis/ValueTracking.h"
19 #include "llvm/CodeGen/Passes.h"
20 #include "llvm/CodeGen/TargetPassConfig.h"
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Function.h"
23 #include "llvm/IR/InstIterator.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/IntrinsicsAMDGPU.h"
26 #include "llvm/IR/PatternMatch.h"
27 #include "llvm/Pass.h"
28 
29 #define DEBUG_TYPE "amdgpu-lower-kernel-attributes"
30 
31 using namespace llvm;
32 
33 namespace {
34 
// Field offsets in hsa_kernel_dispatch_packet_t.
enum DispatchPackedOffsets {
  // Byte offsets of the 16-bit per-dimension workgroup size fields.
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  // Byte offsets of the 32-bit per-dimension grid size fields.
  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};
45 
// Field offsets to implicit kernel argument pointer.
enum ImplicitArgOffsets {
  // Byte offsets of the 32-bit per-dimension block (workgroup) counts.
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  // Byte offsets of the 16-bit per-dimension group sizes.
  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  // Byte offsets of the 16-bit per-dimension partial-workgroup remainders.
  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};
60 
// Legacy pass manager wrapper. The actual folding work is done per call
// site in processUse(); this class only drives it over the module.
class AMDGPULowerKernelAttributes : public ModulePass {
public:
  // Pass identification, replacement for typeid.
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only constant folds IR in place; invalidates no analyses.
    AU.setPreservesAll();
  }
};
77 
78 Function *getBasePtrIntrinsic(Module &M, bool IsV5OrAbove) {
79   auto IntrinsicId = IsV5OrAbove ? Intrinsic::amdgcn_implicitarg_ptr
80                                  : Intrinsic::amdgcn_dispatch_ptr;
81   StringRef Name = Intrinsic::getName(IntrinsicId);
82   return M.getFunction(Name);
83 }
84 
85 } // end anonymous namespace
86 
// Fold loads of workgroup/grid size fields reached through \p CI, a call to
// amdgcn.implicitarg.ptr (code object V5+) or amdgcn.dispatch.ptr (pre-V5),
// using the function's reqd_work_group_size metadata and/or its
// uniform-work-group-size attribute. Returns true if any use was replaced.
static bool processUse(CallInst *CI, bool IsV5OrAbove) {
  Function *F = CI->getParent()->getParent();

  // reqd_work_group_size gives the exact size for each dimension.
  auto MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
    F->getFnAttribute("uniform-work-group-size").getValueAsBool();

  // Without either guarantee there is nothing to fold.
  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
    return false;

  // Loads discovered for each field, indexed by dimension (x, y, z).
  Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
  Value *GroupSizes[3]  = {nullptr, nullptr, nullptr};
  Value *Remainders[3]  = {nullptr, nullptr, nullptr};
  Value *GridSizes[3]   = {nullptr, nullptr, nullptr};

  const DataLayout &DL = F->getDataLayout();

  // We expect to see several GEP users, casted to the appropriate type and
  // loaded.
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    int64_t Offset = 0;
    auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
    auto *BCI = dyn_cast<BitCastInst>(U);
    if (!Load && !BCI) {
      // Otherwise the user must be a constant-offset GEP off the base call.
      if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
        continue;
      Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
      BCI = dyn_cast<BitCastInst>(*U->user_begin());
    }

    if (BCI) {
      if (!BCI->hasOneUse())
        continue;
      Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
    }

    // Only simple (non-volatile, non-atomic) loads are safe to fold.
    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

    // Classify the load by its offset and size into one of the known fields.
    // TODO: Handle merged loads.
    if (IsV5OrAbove) { // Base is ImplicitArgPtr.
      switch (Offset) {
      case HIDDEN_BLOCK_COUNT_X:
        if (LoadSize == 4)
          BlockCounts[0] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Y:
        if (LoadSize == 4)
          BlockCounts[1] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Z:
        if (LoadSize == 4)
          BlockCounts[2] = Load;
        break;
      case HIDDEN_GROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case HIDDEN_REMAINDER_X:
        if (LoadSize == 2)
          Remainders[0] = Load;
        break;
      case HIDDEN_REMAINDER_Y:
        if (LoadSize == 2)
          Remainders[1] = Load;
        break;
      case HIDDEN_REMAINDER_Z:
        if (LoadSize == 2)
          Remainders[2] = Load;
        break;
      default:
        break;
      }
    } else { // Base is DispatchPtr.
      switch (Offset) {
      case WORKGROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case WORKGROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case WORKGROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case GRID_SIZE_X:
        if (LoadSize == 4)
          GridSizes[0] = Load;
        break;
      case GRID_SIZE_Y:
        if (LoadSize == 4)
          GridSizes[1] = Load;
        break;
      case GRID_SIZE_Z:
        if (LoadSize == 4)
          GridSizes[2] = Load;
        break;
      default:
        break;
      }
    }
  }

  bool MadeChange = false;
  if (IsV5OrAbove && HasUniformWorkGroupSize) {
    // Under V5, __ockl_get_local_size returns the value computed by the
    // expression:
    //
    //   workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
    //
    // For functions with the attribute uniform-work-group-size=true, we can
    // evaluate workgroup_id < hidden_block_count as true, and thus
    // hidden_group_size is returned for __ockl_get_local_size.
    for (int I = 0; I < 3; ++I) {
      Value *BlockCount = BlockCounts[I];
      if (!BlockCount)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      // Fold every `workgroup_id < block_count` comparison to true.
      for (User *ICmp : BlockCount->users()) {
        if (match(ICmp, m_SpecificICmp(ICmpInst::ICMP_ULT, GroupIDIntrin,
                                       m_Specific(BlockCount)))) {
          ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
          MadeChange = true;
        }
      }
    }

    // All remainders should be 0 with uniform work group size.
    for (Value *Remainder : Remainders) {
      if (!Remainder)
        continue;
      Remainder->replaceAllUsesWith(Constant::getNullValue(Remainder->getType()));
      MadeChange = true;
    }
  } else if (HasUniformWorkGroupSize) { // Pre-V5.
    // Pattern match the code used to handle partial workgroup dispatches in the
    // library implementation of get_local_size, so the entire function can be
    // constant folded with a known group size.
    //
    // uint r = grid_size - group_id * group_size;
    // get_local_size = (r < group_size) ? r : group_size;
    //
    // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
    // the grid_size is required to be a multiple of group_size. In this case:
    //
    // grid_size - (group_id * group_size) < group_size
    // ->
    // grid_size < group_size + (group_id * group_size)
    //
    // (grid_size / group_size) < 1 + group_id
    //
    // grid_size / group_size is at least 1, so we can conclude the select
    // condition is false (except for group_id == 0, where the select result is
    // the same).
    for (int I = 0; I < 3; ++I) {
      Value *GroupSize = GroupSizes[I];
      Value *GridSize = GridSizes[I];
      if (!GroupSize || !GridSize)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

      // The 16-bit group size is zero-extended before use in the umin pattern.
      for (User *U : GroupSize->users()) {
        auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
        if (!ZextGroupSize)
          continue;

        for (User *UMin : ZextGroupSize->users()) {
          if (match(UMin,
                    m_UMin(m_Sub(m_Specific(GridSize),
                                 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
                           m_Specific(ZextGroupSize)))) {
            if (HasReqdWorkGroupSize) {
              // With an exact size, fold the whole umin to the constant.
              ConstantInt *KnownSize
                = mdconst::extract<ConstantInt>(MD->getOperand(I));
              UMin->replaceAllUsesWith(ConstantFoldIntegerCast(
                  KnownSize, UMin->getType(), false, DL));
            } else {
              // Otherwise the select result is always the group size.
              UMin->replaceAllUsesWith(ZextGroupSize);
            }

            MadeChange = true;
          }
        }
      }
    }
  }

  // If reqd_work_group_size is set, we can replace work group size with it.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

  for (int I = 0; I < 3; I++) {
    Value *GroupSize = GroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
        ConstantFoldIntegerCast(KnownSize, GroupSize->getType(), false, DL));
    MadeChange = true;
  }

  return MadeChange;
}
318 
319 
320 // TODO: Move makeLIDRangeMetadata usage into here. Seem to not get
321 // TargetPassConfig for subtarget.
322 bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
323   bool MadeChange = false;
324   bool IsV5OrAbove =
325       AMDGPU::getAMDHSACodeObjectVersion(M) >= AMDGPU::AMDHSA_COV5;
326   Function *BasePtr = getBasePtrIntrinsic(M, IsV5OrAbove);
327 
328   if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
329     return false;
330 
331   SmallPtrSet<Instruction *, 4> HandledUses;
332   for (auto *U : BasePtr->users()) {
333     CallInst *CI = cast<CallInst>(U);
334     if (HandledUses.insert(CI).second) {
335       if (processUse(CI, IsV5OrAbove))
336         MadeChange = true;
337     }
338   }
339 
340   return MadeChange;
341 }
342 
343 
// Register the legacy pass with the pass registry.
INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

// Pass identification token used by the legacy pass manager.
char AMDGPULowerKernelAttributes::ID = 0;
350 
// Factory used when building the legacy pass-manager pipeline.
ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
  return new AMDGPULowerKernelAttributes();
}
354 
355 PreservedAnalyses
356 AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
357   bool IsV5OrAbove =
358       AMDGPU::getAMDHSACodeObjectVersion(*F.getParent()) >= AMDGPU::AMDHSA_COV5;
359   Function *BasePtr = getBasePtrIntrinsic(*F.getParent(), IsV5OrAbove);
360 
361   if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
362     return PreservedAnalyses::all();
363 
364   for (Instruction &I : instructions(F)) {
365     if (CallInst *CI = dyn_cast<CallInst>(&I)) {
366       if (CI->getCalledFunction() == BasePtr)
367         processUse(CI, IsV5OrAbove);
368     }
369   }
370 
371   return PreservedAnalyses::all();
372 }
373