xref: /llvm-project/offload/DeviceRTL/src/Mapping.cpp (revision f53cb84df6b80458cb4d5ab7398a590356a3a952)
1 //===------- Mapping.cpp - OpenMP device runtime mapping helpers -- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #include "Mapping.h"
13 #include "DeviceTypes.h"
14 #include "DeviceUtils.h"
15 #include "Interface.h"
16 #include "State.h"
17 
18 #pragma omp begin declare target device_type(nohost)
19 
20 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
21 
22 using namespace ompx;
23 
24 namespace ompx {
25 namespace impl {
26 
27 // Forward declarations defined to be defined for AMDGCN and NVPTX.
28 LaneMaskTy activemask();
29 LaneMaskTy lanemaskLT();
30 LaneMaskTy lanemaskGT();
31 uint32_t getThreadIdInWarp();
32 uint32_t getThreadIdInBlock(int32_t Dim);
33 uint32_t getNumberOfThreadsInBlock(int32_t Dim);
34 uint32_t getNumberOfThreadsInKernel();
35 uint32_t getBlockIdInKernel(int32_t Dim);
36 uint32_t getNumberOfBlocksInKernel(int32_t Dim);
37 uint32_t getWarpIdInBlock();
38 uint32_t getNumberOfWarpsInBlock();
39 uint32_t getWarpSize();
40 
41 /// AMDGCN Implementation
42 ///
43 ///{
44 #pragma omp begin declare variant match(device = {arch(amdgcn)})
45 
46 uint32_t getWarpSize() { return __builtin_amdgcn_wavefrontsize(); }
47 
48 uint32_t getNumberOfThreadsInBlock(int32_t Dim) {
49   switch (Dim) {
50   case 0:
51     return __builtin_amdgcn_workgroup_size_x();
52   case 1:
53     return __builtin_amdgcn_workgroup_size_y();
54   case 2:
55     return __builtin_amdgcn_workgroup_size_z();
56   };
57   UNREACHABLE("Dim outside range!");
58 }
59 
60 LaneMaskTy activemask() { return __builtin_amdgcn_read_exec(); }
61 
62 LaneMaskTy lanemaskLT() {
63   uint32_t Lane = mapping::getThreadIdInWarp();
64   int64_t Ballot = mapping::activemask();
65   uint64_t Mask = ((uint64_t)1 << Lane) - (uint64_t)1;
66   return Mask & Ballot;
67 }
68 
69 LaneMaskTy lanemaskGT() {
70   uint32_t Lane = mapping::getThreadIdInWarp();
71   if (Lane == (mapping::getWarpSize() - 1))
72     return 0;
73   int64_t Ballot = mapping::activemask();
74   uint64_t Mask = (~((uint64_t)0)) << (Lane + 1);
75   return Mask & Ballot;
76 }
77 
78 uint32_t getThreadIdInWarp() {
79   return __builtin_amdgcn_mbcnt_hi(~0u, __builtin_amdgcn_mbcnt_lo(~0u, 0u));
80 }
81 
82 uint32_t getThreadIdInBlock(int32_t Dim) {
83   switch (Dim) {
84   case 0:
85     return __builtin_amdgcn_workitem_id_x();
86   case 1:
87     return __builtin_amdgcn_workitem_id_y();
88   case 2:
89     return __builtin_amdgcn_workitem_id_z();
90   };
91   UNREACHABLE("Dim outside range!");
92 }
93 
94 uint32_t getNumberOfThreadsInKernel() {
95   return __builtin_amdgcn_grid_size_x() * __builtin_amdgcn_grid_size_y() *
96          __builtin_amdgcn_grid_size_z();
97 }
98 
99 uint32_t getBlockIdInKernel(int32_t Dim) {
100   switch (Dim) {
101   case 0:
102     return __builtin_amdgcn_workgroup_id_x();
103   case 1:
104     return __builtin_amdgcn_workgroup_id_y();
105   case 2:
106     return __builtin_amdgcn_workgroup_id_z();
107   };
108   UNREACHABLE("Dim outside range!");
109 }
110 
111 uint32_t getNumberOfBlocksInKernel(int32_t Dim) {
112   switch (Dim) {
113   case 0:
114     return __builtin_amdgcn_grid_size_x() / __builtin_amdgcn_workgroup_size_x();
115   case 1:
116     return __builtin_amdgcn_grid_size_y() / __builtin_amdgcn_workgroup_size_y();
117   case 2:
118     return __builtin_amdgcn_grid_size_z() / __builtin_amdgcn_workgroup_size_z();
119   };
120   UNREACHABLE("Dim outside range!");
121 }
122 
123 uint32_t getWarpIdInBlock() {
124   return impl::getThreadIdInBlock(mapping::DIM_X) / mapping::getWarpSize();
125 }
126 
127 uint32_t getNumberOfWarpsInBlock() {
128   return mapping::getNumberOfThreadsInBlock() / mapping::getWarpSize();
129 }
130 
131 #pragma omp end declare variant
132 ///}
133 
134 /// NVPTX Implementation
135 ///
136 ///{
137 #pragma omp begin declare variant match(                                       \
138         device = {arch(nvptx, nvptx64)},                                       \
139             implementation = {extension(match_any)})
140 
141 uint32_t getNumberOfThreadsInBlock(int32_t Dim) {
142   switch (Dim) {
143   case 0:
144     return __nvvm_read_ptx_sreg_ntid_x();
145   case 1:
146     return __nvvm_read_ptx_sreg_ntid_y();
147   case 2:
148     return __nvvm_read_ptx_sreg_ntid_z();
149   };
150   UNREACHABLE("Dim outside range!");
151 }
152 
153 uint32_t getWarpSize() { return __nvvm_read_ptx_sreg_warpsize(); }
154 
155 LaneMaskTy activemask() { return __nvvm_activemask(); }
156 
157 LaneMaskTy lanemaskLT() { return __nvvm_read_ptx_sreg_lanemask_lt(); }
158 
159 LaneMaskTy lanemaskGT() { return __nvvm_read_ptx_sreg_lanemask_gt(); }
160 
161 uint32_t getThreadIdInBlock(int32_t Dim) {
162   switch (Dim) {
163   case 0:
164     return __nvvm_read_ptx_sreg_tid_x();
165   case 1:
166     return __nvvm_read_ptx_sreg_tid_y();
167   case 2:
168     return __nvvm_read_ptx_sreg_tid_z();
169   };
170   UNREACHABLE("Dim outside range!");
171 }
172 
173 uint32_t getThreadIdInWarp() { return __nvvm_read_ptx_sreg_laneid(); }
174 
175 uint32_t getBlockIdInKernel(int32_t Dim) {
176   switch (Dim) {
177   case 0:
178     return __nvvm_read_ptx_sreg_ctaid_x();
179   case 1:
180     return __nvvm_read_ptx_sreg_ctaid_y();
181   case 2:
182     return __nvvm_read_ptx_sreg_ctaid_z();
183   };
184   UNREACHABLE("Dim outside range!");
185 }
186 
187 uint32_t getNumberOfBlocksInKernel(int32_t Dim) {
188   switch (Dim) {
189   case 0:
190     return __nvvm_read_ptx_sreg_nctaid_x();
191   case 1:
192     return __nvvm_read_ptx_sreg_nctaid_y();
193   case 2:
194     return __nvvm_read_ptx_sreg_nctaid_z();
195   };
196   UNREACHABLE("Dim outside range!");
197 }
198 
199 uint32_t getNumberOfThreadsInKernel() {
200   return impl::getNumberOfThreadsInBlock(0) *
201          impl::getNumberOfBlocksInKernel(0) *
202          impl::getNumberOfThreadsInBlock(1) *
203          impl::getNumberOfBlocksInKernel(1) *
204          impl::getNumberOfThreadsInBlock(2) *
205          impl::getNumberOfBlocksInKernel(2);
206 }
207 
208 uint32_t getWarpIdInBlock() {
209   return impl::getThreadIdInBlock(mapping::DIM_X) / mapping::getWarpSize();
210 }
211 
212 uint32_t getNumberOfWarpsInBlock() {
213   return (mapping::getNumberOfThreadsInBlock() + mapping::getWarpSize() - 1) /
214          mapping::getWarpSize();
215 }
216 
217 #pragma omp end declare variant
218 ///}
219 
220 } // namespace impl
221 } // namespace ompx
222 
223 /// We have to be deliberate about the distinction of `mapping::` and `impl::`
224 /// below to avoid repeating assumptions or including irrelevant ones.
225 ///{
226 
227 static bool isInLastWarp() {
228   uint32_t MainTId = (mapping::getNumberOfThreadsInBlock() - 1) &
229                      ~(mapping::getWarpSize() - 1);
230   return mapping::getThreadIdInBlock() == MainTId;
231 }
232 
233 bool mapping::isMainThreadInGenericMode(bool IsSPMD) {
234   if (IsSPMD || icv::Level)
235     return false;
236 
237   // Check if this is the last warp in the block.
238   return isInLastWarp();
239 }
240 
241 bool mapping::isMainThreadInGenericMode() {
242   return mapping::isMainThreadInGenericMode(mapping::isSPMDMode());
243 }
244 
245 bool mapping::isInitialThreadInLevel0(bool IsSPMD) {
246   if (IsSPMD)
247     return mapping::getThreadIdInBlock() == 0;
248   return isInLastWarp();
249 }
250 
251 bool mapping::isLeaderInWarp() {
252   __kmpc_impl_lanemask_t Active = mapping::activemask();
253   __kmpc_impl_lanemask_t LaneMaskLT = mapping::lanemaskLT();
254   return utils::popc(Active & LaneMaskLT) == 0;
255 }
256 
257 LaneMaskTy mapping::activemask() { return impl::activemask(); }
258 
259 LaneMaskTy mapping::lanemaskLT() { return impl::lanemaskLT(); }
260 
261 LaneMaskTy mapping::lanemaskGT() { return impl::lanemaskGT(); }
262 
263 uint32_t mapping::getThreadIdInWarp() {
264   uint32_t ThreadIdInWarp = impl::getThreadIdInWarp();
265   ASSERT(ThreadIdInWarp < impl::getWarpSize(), nullptr);
266   return ThreadIdInWarp;
267 }
268 
269 uint32_t mapping::getThreadIdInBlock(int32_t Dim) {
270   uint32_t ThreadIdInBlock = impl::getThreadIdInBlock(Dim);
271   return ThreadIdInBlock;
272 }
273 
274 uint32_t mapping::getWarpSize() { return impl::getWarpSize(); }
275 
276 uint32_t mapping::getMaxTeamThreads(bool IsSPMD) {
277   uint32_t BlockSize = mapping::getNumberOfThreadsInBlock();
278   // If we are in SPMD mode, remove one warp.
279   return BlockSize - (!IsSPMD * impl::getWarpSize());
280 }
281 uint32_t mapping::getMaxTeamThreads() {
282   return mapping::getMaxTeamThreads(mapping::isSPMDMode());
283 }
284 
285 uint32_t mapping::getNumberOfThreadsInBlock(int32_t Dim) {
286   return impl::getNumberOfThreadsInBlock(Dim);
287 }
288 
289 uint32_t mapping::getNumberOfThreadsInKernel() {
290   return impl::getNumberOfThreadsInKernel();
291 }
292 
293 uint32_t mapping::getWarpIdInBlock() {
294   uint32_t WarpID = impl::getWarpIdInBlock();
295   ASSERT(WarpID < impl::getNumberOfWarpsInBlock(), nullptr);
296   return WarpID;
297 }
298 
299 uint32_t mapping::getBlockIdInKernel(int32_t Dim) {
300   uint32_t BlockId = impl::getBlockIdInKernel(Dim);
301   ASSERT(BlockId < impl::getNumberOfBlocksInKernel(Dim), nullptr);
302   return BlockId;
303 }
304 
305 uint32_t mapping::getNumberOfWarpsInBlock() {
306   uint32_t NumberOfWarpsInBlocks = impl::getNumberOfWarpsInBlock();
307   ASSERT(impl::getWarpIdInBlock() < NumberOfWarpsInBlocks, nullptr);
308   return NumberOfWarpsInBlocks;
309 }
310 
311 uint32_t mapping::getNumberOfBlocksInKernel(int32_t Dim) {
312   uint32_t NumberOfBlocks = impl::getNumberOfBlocksInKernel(Dim);
313   ASSERT(impl::getBlockIdInKernel(Dim) < NumberOfBlocks, nullptr);
314   return NumberOfBlocks;
315 }
316 
317 uint32_t mapping::getNumberOfProcessorElements() {
318   return static_cast<uint32_t>(config::getHardwareParallelism());
319 }
320 
321 ///}
322 
323 /// Execution mode
324 ///
325 ///{
326 
327 // TODO: This is a workaround for initialization coming from kernels outside of
328 //       the TU. We will need to solve this more correctly in the future.
329 [[gnu::weak]] int SHARED(IsSPMDMode);
330 
331 void mapping::init(bool IsSPMD) {
332   if (mapping::isInitialThreadInLevel0(IsSPMD))
333     IsSPMDMode = IsSPMD;
334 }
335 
336 bool mapping::isSPMDMode() { return IsSPMDMode; }
337 
338 bool mapping::isGenericMode() { return !isSPMDMode(); }
339 ///}
340 
341 extern "C" {
342 [[gnu::noinline]] uint32_t __kmpc_get_hardware_thread_id_in_block() {
343   return mapping::getThreadIdInBlock();
344 }
345 
346 [[gnu::noinline]] uint32_t __kmpc_get_hardware_num_threads_in_block() {
347   return impl::getNumberOfThreadsInBlock(mapping::DIM_X);
348 }
349 
350 [[gnu::noinline]] uint32_t __kmpc_get_warp_size() {
351   return impl::getWarpSize();
352 }
353 }
354 
355 #define _TGT_KERNEL_LANGUAGE(NAME, MAPPER_NAME)                                \
356   extern "C" int ompx_##NAME(int Dim) { return mapping::MAPPER_NAME(Dim); }
357 
358 _TGT_KERNEL_LANGUAGE(thread_id, getThreadIdInBlock)
359 _TGT_KERNEL_LANGUAGE(block_id, getBlockIdInKernel)
360 _TGT_KERNEL_LANGUAGE(block_dim, getNumberOfThreadsInBlock)
361 _TGT_KERNEL_LANGUAGE(grid_dim, getNumberOfBlocksInKernel)
362 
363 extern "C" {
364 uint64_t ompx_ballot_sync(uint64_t mask, int pred) {
365   return utils::ballotSync(mask, pred);
366 }
367 
368 int ompx_shfl_down_sync_i(uint64_t mask, int var, unsigned delta, int width) {
369   return utils::shuffleDown(mask, var, delta, width);
370 }
371 
372 float ompx_shfl_down_sync_f(uint64_t mask, float var, unsigned delta,
373                             int width) {
374   return utils::bitCast<float>(
375       utils::shuffleDown(mask, utils::bitCast<int32_t>(var), delta, width));
376 }
377 
378 long ompx_shfl_down_sync_l(uint64_t mask, long var, unsigned delta, int width) {
379   return utils::shuffleDown(mask, var, delta, width);
380 }
381 
382 double ompx_shfl_down_sync_d(uint64_t mask, double var, unsigned delta,
383                              int width) {
384   return utils::bitCast<double>(
385       utils::shuffleDown(mask, utils::bitCast<int64_t>(var), delta, width));
386 }
387 }
388 
389 #pragma omp end declare target
390