//===-- AMDGPULowerKernelAttributes.cpp -----------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file This pass attempts to make use of reqd_work_group_size metadata
/// to eliminate loads from the dispatch packet and to constant fold OpenCL
/// get_local_size-like functions.
//
//===----------------------------------------------------------------------===//

#include "AMDGPU.h"
#include "Utils/AMDGPUBaseInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/Passes.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"

#define DEBUG_TYPE "amdgpu-lower-kernel-attributes"

using namespace llvm;

namespace {

// Field offsets in hsa_kernel_dispatch_packet_t.
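// The packet layout is fixed by the HSA ABI: the workgroup_size_{x,y,z}
// fields are 16-bit and the grid_size_{x,y,z} fields are 32-bit, which is why
// the matching code below checks the width of each load as well as its offset.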
enum DispatchPacketOffsets {
  WORKGROUP_SIZE_X = 4,
  WORKGROUP_SIZE_Y = 6,
  WORKGROUP_SIZE_Z = 8,

  GRID_SIZE_X = 12,
  GRID_SIZE_Y = 16,
  GRID_SIZE_Z = 20
};

// Field offsets relative to the implicit kernel argument pointer.
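// These match the code object v5 hidden kernel arguments: the block counts
// are 32-bit fields, while the group sizes and remainders are 16-bit fields.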
enum ImplicitArgOffsets {
  HIDDEN_BLOCK_COUNT_X = 0,
  HIDDEN_BLOCK_COUNT_Y = 4,
  HIDDEN_BLOCK_COUNT_Z = 8,

  HIDDEN_GROUP_SIZE_X = 12,
  HIDDEN_GROUP_SIZE_Y = 14,
  HIDDEN_GROUP_SIZE_Z = 16,

  HIDDEN_REMAINDER_X = 18,
  HIDDEN_REMAINDER_Y = 20,
  HIDDEN_REMAINDER_Z = 22,
};

class AMDGPULowerKernelAttributes : public ModulePass {
public:
  static char ID;

  AMDGPULowerKernelAttributes() : ModulePass(ID) {}

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override {
    return "AMDGPU Kernel Attributes";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.setPreservesAll();
  }
};

} // end anonymous namespace

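// Fold uses of the given implicitarg.ptr/dispatch.ptr call using the kernel's
// reqd_work_group_size metadata and uniform-work-group-size attribute.
// Returns true if any use was replaced.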
static bool processUse(CallInst *CI, bool IsV5OrAbove) {
  Function *F = CI->getFunction();

  auto *MD = F->getMetadata("reqd_work_group_size");
  const bool HasReqdWorkGroupSize = MD && MD->getNumOperands() == 3;

  const bool HasUniformWorkGroupSize =
      F->getFnAttribute("uniform-work-group-size").getValueAsBool();

  if (!HasReqdWorkGroupSize && !HasUniformWorkGroupSize)
    return false;

  Value *BlockCounts[3] = {nullptr, nullptr, nullptr};
  Value *GroupSizes[3]  = {nullptr, nullptr, nullptr};
  Value *Remainders[3]  = {nullptr, nullptr, nullptr};
  Value *GridSizes[3]   = {nullptr, nullptr, nullptr};

  const DataLayout &DL = F->getParent()->getDataLayout();

  // We expect to see several GEP users, cast to the appropriate type and
  // loaded.
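  // For example (illustrative IR only; offsets follow the tables above):
  //   %gep = getelementptr i8, ptr addrspace(4) %implicitarg.ptr, i64 12
  //   %group.size.x = load i16, ptr addrspace(4) %gep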
  for (User *U : CI->users()) {
    if (!U->hasOneUse())
      continue;

    int64_t Offset = 0;
    auto *Load = dyn_cast<LoadInst>(U); // Load from ImplicitArgPtr/DispatchPtr?
    auto *BCI = dyn_cast<BitCastInst>(U);
    if (!Load && !BCI) {
      if (GetPointerBaseWithConstantOffset(U, Offset, DL) != CI)
        continue;
      Load = dyn_cast<LoadInst>(*U->user_begin()); // Load from GEP?
      BCI = dyn_cast<BitCastInst>(*U->user_begin());
    }

    if (BCI) {
      if (!BCI->hasOneUse())
        continue;
      Load = dyn_cast<LoadInst>(*BCI->user_begin()); // Load from BCI?
    }

    if (!Load || !Load->isSimple())
      continue;

    unsigned LoadSize = DL.getTypeStoreSize(Load->getType());

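    // Each field has a fixed width in the ABI, so only accept a load whose
    // width matches the field at this offset.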
    // TODO: Handle merged loads.
    if (IsV5OrAbove) { // Base is ImplicitArgPtr.
      switch (Offset) {
      case HIDDEN_BLOCK_COUNT_X:
        if (LoadSize == 4)
          BlockCounts[0] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Y:
        if (LoadSize == 4)
          BlockCounts[1] = Load;
        break;
      case HIDDEN_BLOCK_COUNT_Z:
        if (LoadSize == 4)
          BlockCounts[2] = Load;
        break;
      case HIDDEN_GROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case HIDDEN_GROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case HIDDEN_REMAINDER_X:
        if (LoadSize == 2)
          Remainders[0] = Load;
        break;
      case HIDDEN_REMAINDER_Y:
        if (LoadSize == 2)
          Remainders[1] = Load;
        break;
      case HIDDEN_REMAINDER_Z:
        if (LoadSize == 2)
          Remainders[2] = Load;
        break;
      default:
        break;
      }
    } else { // Base is DispatchPtr.
      switch (Offset) {
      case WORKGROUP_SIZE_X:
        if (LoadSize == 2)
          GroupSizes[0] = Load;
        break;
      case WORKGROUP_SIZE_Y:
        if (LoadSize == 2)
          GroupSizes[1] = Load;
        break;
      case WORKGROUP_SIZE_Z:
        if (LoadSize == 2)
          GroupSizes[2] = Load;
        break;
      case GRID_SIZE_X:
        if (LoadSize == 4)
          GridSizes[0] = Load;
        break;
      case GRID_SIZE_Y:
        if (LoadSize == 4)
          GridSizes[1] = Load;
        break;
      case GRID_SIZE_Z:
        if (LoadSize == 4)
          GridSizes[2] = Load;
        break;
      default:
        break;
      }
    }
  }

  bool MadeChange = false;
  if (IsV5OrAbove && HasUniformWorkGroupSize) {
    // Under v5, __ockl_get_local_size returns the value computed by the
    // expression:
    //
    //   workgroup_id < hidden_block_count ? hidden_group_size : hidden_remainder
    //
    // For functions with the attribute uniform-work-group-size=true, we can
    // evaluate workgroup_id < hidden_block_count as true, and thus
    // __ockl_get_local_size returns hidden_group_size.
    for (int I = 0; I < 3; ++I) {
      Value *BlockCount = BlockCounts[I];
      if (!BlockCount)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());
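
      // Only fold the unsigned compare the library emits
      // (workgroup_id u< hidden_block_count); with a uniform work group size,
      // every launched group is a full group, so the compare is always true.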
      for (User *ICmp : BlockCount->users()) {
        ICmpInst::Predicate Pred;
        if (match(ICmp, m_ICmp(Pred, GroupIDIntrin, m_Specific(BlockCount)))) {
          if (Pred != ICmpInst::ICMP_ULT)
            continue;
          ICmp->replaceAllUsesWith(llvm::ConstantInt::getTrue(ICmp->getType()));
          MadeChange = true;
        }
      }
    }

    // All remainders should be 0 with uniform work group size.
    for (Value *Remainder : Remainders) {
      if (!Remainder)
        continue;
      Remainder->replaceAllUsesWith(
          Constant::getNullValue(Remainder->getType()));
      MadeChange = true;
    }
  } else if (HasUniformWorkGroupSize) { // Pre-V5.
    // Pattern match the code used to handle partial workgroup dispatches in
    // the library implementation of get_local_size, so the entire function can
    // be constant folded with a known group size.
    //
    // uint r = grid_size - group_id * group_size;
    // get_local_size = (r < group_size) ? r : group_size;
    //
    // If we have uniform-work-group-size (which is the default in OpenCL 1.2),
    // grid_size is required to be a multiple of group_size. In this case:
    //
    // grid_size - (group_id * group_size) < group_size
    // ->
    // grid_size < group_size + (group_id * group_size)
    //
    // and, dividing both sides by group_size,
    //
    // (grid_size / group_size) < 1 + group_id
    //
    // grid_size / group_size is at least 1, so we can conclude the select
    // condition is false (except for group_id == 0, where the select result is
    // the same).
    for (int I = 0; I < 3; ++I) {
      Value *GroupSize = GroupSizes[I];
      Value *GridSize = GridSizes[I];
      if (!GroupSize || !GridSize)
        continue;

      using namespace llvm::PatternMatch;
      auto GroupIDIntrin =
          I == 0 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_x>()
                 : (I == 1 ? m_Intrinsic<Intrinsic::amdgcn_workgroup_id_y>()
                           : m_Intrinsic<Intrinsic::amdgcn_workgroup_id_z>());

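      // get_local_size zero-extends the 16-bit group size field before the
      // umin, so look through a zext of the load when matching.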
      for (User *U : GroupSize->users()) {
        auto *ZextGroupSize = dyn_cast<ZExtInst>(U);
        if (!ZextGroupSize)
          continue;

        for (User *UMin : ZextGroupSize->users()) {
          if (match(UMin,
                    m_UMin(m_Sub(m_Specific(GridSize),
                                 m_Mul(GroupIDIntrin, m_Specific(ZextGroupSize))),
                           m_Specific(ZextGroupSize)))) {
            if (HasReqdWorkGroupSize) {
              ConstantInt *KnownSize =
                  mdconst::extract<ConstantInt>(MD->getOperand(I));
              UMin->replaceAllUsesWith(ConstantExpr::getIntegerCast(
                  KnownSize, UMin->getType(), false));
            } else {
              UMin->replaceAllUsesWith(ZextGroupSize);
            }

            MadeChange = true;
          }
        }
      }
    }
  }

  // If reqd_work_group_size is set, we can replace the work group size loads
  // with it.
  if (!HasReqdWorkGroupSize)
    return MadeChange;

  for (int I = 0; I < 3; ++I) {
    Value *GroupSize = GroupSizes[I];
    if (!GroupSize)
      continue;

    ConstantInt *KnownSize = mdconst::extract<ConstantInt>(MD->getOperand(I));
    GroupSize->replaceAllUsesWith(
        ConstantExpr::getIntegerCast(KnownSize, GroupSize->getType(), false));
    MadeChange = true;
  }

  return MadeChange;
}

// TODO: Move makeLIDRangeMetadata usage into here. We don't seem to get
// TargetPassConfig for the subtarget here.
bool AMDGPULowerKernelAttributes::runOnModule(Module &M) {
  bool MadeChange = false;
  Function *BasePtr = nullptr;
  bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
  if (IsV5OrAbove) {
    StringRef ImplicitArgPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_implicitarg_ptr);
    BasePtr = M.getFunction(ImplicitArgPtrName);
  } else { // Pre-V5.
    StringRef DispatchPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
    BasePtr = M.getFunction(DispatchPtrName);
  }

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return false;

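  // The intrinsic's use list may contain the same call more than once; track
  // handled calls so each call site is processed only once.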
  SmallPtrSet<Instruction *, 4> HandledUses;
  for (auto *U : BasePtr->users()) {
    CallInst *CI = cast<CallInst>(U);
    if (HandledUses.insert(CI).second) {
      if (processUse(CI, IsV5OrAbove))
        MadeChange = true;
    }
  }

  return MadeChange;
}

INITIALIZE_PASS_BEGIN(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                      "AMDGPU Kernel Attributes", false, false)
INITIALIZE_PASS_END(AMDGPULowerKernelAttributes, DEBUG_TYPE,
                    "AMDGPU Kernel Attributes", false, false)

char AMDGPULowerKernelAttributes::ID = 0;

ModulePass *llvm::createAMDGPULowerKernelAttributesPass() {
  return new AMDGPULowerKernelAttributes();
}

PreservedAnalyses
AMDGPULowerKernelAttributesPass::run(Function &F, FunctionAnalysisManager &AM) {
  Function *BasePtr = nullptr;
  bool IsV5OrAbove = AMDGPU::getAmdhsaCodeObjectVersion() >= 5;
  if (IsV5OrAbove) {
    StringRef ImplicitArgPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_implicitarg_ptr);
    BasePtr = F.getParent()->getFunction(ImplicitArgPtrName);
  } else { // Pre-V5.
    StringRef DispatchPtrName =
        Intrinsic::getName(Intrinsic::amdgcn_dispatch_ptr);
    BasePtr = F.getParent()->getFunction(DispatchPtrName);
  }

  if (!BasePtr) // ImplicitArgPtr/DispatchPtr not used.
    return PreservedAnalyses::all();

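  // Unlike the legacy module pass, only calls within this function are
  // visited, rather than the full use list of the intrinsic declaration.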
  for (Instruction &I : instructions(F)) {
    if (CallInst *CI = dyn_cast<CallInst>(&I)) {
      if (CI->getCalledFunction() == BasePtr)
        processUse(CI, IsV5OrAbove);
    }
  }

  return PreservedAnalyses::all();
}
383