//===----- Workshare.cpp -  OpenMP workshare implementation ------ C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the implementation of the KMPC interface
// for the loop construct plus other worksharing constructs that use the same
// interface as loops.
//
//===----------------------------------------------------------------------===//

#include "Workshare.h"
#include "Debug.h"
#include "DeviceTypes.h"
#include "DeviceUtils.h"
#include "Interface.h"
#include "Mapping.h"
#include "State.h"
#include "Synchronization.h"

using namespace ompx;

// TODO:
struct DynamicScheduleTracker {
  int64_t Chunk;
  int64_t LoopUpperBound;
  int64_t NextLowerBound;
  int64_t Stride;
  kmp_sched_t ScheduleType;
  DynamicScheduleTracker *NextDST;
};

#define ASSERT0(...)

// used by the library for the interface with the app
#define DISPATCH_FINISHED 0
#define DISPATCH_NOTFINISHED 1

// used by dynamic scheduling
#define FINISHED 0
#define NOT_FINISHED 1
#define LAST_CHUNK 2

#pragma omp begin declare target device_type(nohost)

// TODO: This variable is a hack inherited from the old runtime.
static uint64_t SHARED(Cnt);

template <typename T, typename ST> struct omptarget_nvptx_LoopSupport {
  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling with chunk

  // Generic implementation of OMP loop scheduling with static policy
  /*! \brief Calculate initial bounds for static loop and stride
   *  @param[in] loc location in code of the call (not used here)
   *  @param[in] global_tid global thread id
   *  @param[in] schedtype type of scheduling (see omptarget-nvptx.h)
   *  @param[in] plastiter pointer to the last-iteration flag
   *  @param[in,out] plower pointer to the loop lower bound; on return it
   *  contains the lower bound of the first chunk
   *  @param[in,out] pupper pointer to the loop upper bound; on return it
   *  contains the upper bound of the first chunk
   *  @param[in,out] pstride pointer to the loop stride; on return it contains
   *  the stride between two successive chunks executed by the same thread
   *  @param[in] incr loop increment bump
   *  @param[in] chunk chunk size
   */

  // helper function for static chunk
  static void ForStaticChunk(int &last, T &lb, T &ub, ST &stride, ST chunk,
                             T entityId, T numberOfEntities) {
    // Each thread executes multiple chunks, all of the same size except
    // the last one.
    // Distance between two successive chunks executed by the same thread.
    stride = numberOfEntities * chunk;
    lb = lb + entityId * chunk;
    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    // Say ub' is the beginning of the last chunk. Then whoever has a
    // lower bound plus a multiple of the increment equal to ub' is
    // the last one.
    T beginningLastChunk = inputUb - (inputUb % chunk);
    last = ((beginningLastChunk - lb) % stride) == 0;
  }
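
  // For example (illustrative values): with lb = 0, ub = 99, chunk = 10 and
  // 4 threads, the stride is 40. Thread 1 is handed [10,19] and will later
  // cover [50,59] and [90,99]; the last chunk begins at 99 - (99 % 10) = 90,
  // and (90 - 10) % 40 == 0, so thread 1 reports the last iteration.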

  ////////////////////////////////////////////////////////////////////////////////
  // Loop with static scheduling without chunk

  // helper function for static no chunk
  static void ForStaticNoChunk(int &last, T &lb, T &ub, ST &stride, ST &chunk,
                               T entityId, T numberOfEntities) {
    // No chunk size specified. Each thread or warp gets at most one
    // chunk; all chunks are of nearly equal size.
    T loopSize = ub - lb + 1;

    chunk = loopSize / numberOfEntities;
    T leftOver = loopSize - chunk * numberOfEntities;

    if (entityId < leftOver) {
      chunk++;
      lb = lb + entityId * chunk;
    } else {
      lb = lb + entityId * chunk + leftOver;
    }

    T inputUb = ub;
    ub = lb + chunk - 1; // Clang uses i <= ub
    last = lb <= inputUb && inputUb <= ub;
    stride = loopSize; // make sure we only do 1 chunk per warp
  }
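
  // For example (illustrative values): with lb = 0, ub = 102 (103 iterations)
  // and 4 threads, chunk starts as 25 with a leftover of 3. Threads 0-2 get 26
  // iterations each ([0,25], [26,51], [52,77]); thread 3 keeps 25 and gets
  // [78,102], so it reports the last iteration.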

  ////////////////////////////////////////////////////////////////////////////////
  // Support for Static Init

  static void for_static_init(int32_t, int32_t schedtype, int32_t *plastiter,
                              T *plower, T *pupper, ST *pstride, ST chunk,
                              bool IsSPMDExecutionMode) {
    int32_t gtid = omp_get_thread_num();
    int numberOfActiveOMPThreads = omp_get_num_threads();

    // All warps that are in excess of the maximum requested do
    // not execute the loop.
    ASSERT0(LT_FUSSY, gtid < numberOfActiveOMPThreads,
            "current thread is not needed here; error");

    // copy
    int lastiter = 0;
    T lb = *plower;
    T ub = *pupper;
    ST stride = *pstride;

    // init
    switch (SCHEDULE_WITHOUT_MODIFIERS(schedtype)) {
    case kmp_sched_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_balanced_chunk: {
      if (chunk > 0) {
        // round up to make sure the chunk is enough to cover all iterations
        T tripCount = ub - lb + 1; // +1 because ub is inclusive
        T span = (tripCount + numberOfActiveOMPThreads - 1) /
                 numberOfActiveOMPThreads;
        // perform chunk adjustment
        chunk = (span + chunk - 1) & ~(chunk - 1);
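        // This rounds span up to a multiple of chunk; e.g. span = 13 and
        // chunk = 4 give (13 + 3) & ~3 = 16. (The mask form assumes the
        // chunk is a power of two.)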

        ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
        T oldUb = ub;
        ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
        if (ub > oldUb)
          ub = oldUb;
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, gtid,
                       numberOfActiveOMPThreads);
      break;
    }
    case kmp_sched_distr_static_chunk: {
      if (chunk > 0) {
        ForStaticChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
        break;
      }
      [[fallthrough]];
    } // note: if chunk <=0, use nochunk
    case kmp_sched_distr_static_nochunk: {
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, omp_get_team_num(),
                       omp_get_num_teams());
      break;
    }
    case kmp_sched_distr_static_chunk_sched_static_chunkone: {
      ForStaticChunk(lastiter, lb, ub, stride, chunk,
                     numberOfActiveOMPThreads * omp_get_team_num() + gtid,
                     omp_get_num_teams() * numberOfActiveOMPThreads);
      break;
    }
    default: {
      // ASSERT(LT_FUSSY, 0, "unknown schedtype %d", (int)schedtype);
      ForStaticChunk(lastiter, lb, ub, stride, chunk, gtid,
                     numberOfActiveOMPThreads);
      break;
    }
    }
    // copy back
    *plastiter = lastiter;
    *plower = lb;
    *pupper = ub;
    *pstride = stride;
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch Init

  static int OrderedSchedule(kmp_sched_t schedule) {
    return schedule >= kmp_sched_ordered_first &&
           schedule <= kmp_sched_ordered_last;
  }

  static void dispatch_init(IdentTy *loc, int32_t threadId,
                            kmp_sched_t schedule, T lb, T ub, ST st, ST chunk,
                            DynamicScheduleTracker *DST) {
    int tid = mapping::getThreadIdInBlock();
    T tnum = omp_get_num_threads();
    T tripCount = ub - lb + 1; // +1 because ub is inclusive
    ASSERT0(LT_FUSSY, threadId < tnum,
            "current thread is not needed here; error");

    /* Currently just ignore the monotonic and non-monotonic modifiers
     * (the compiler isn't producing them yet anyway).
     * When it does, we'll want to look at them somewhere here and use that
     * information to add to our schedule choice. We shouldn't need to pass
     * them on; they merely affect which schedule we can legally choose for
     * various dynamic cases (in particular, whether or not a stealing scheme
     * is legal).
     */
    schedule = SCHEDULE_WITHOUT_MODIFIERS(schedule);

    // Process schedule.
    if (tnum == 1 || tripCount <= 1 || OrderedSchedule(schedule)) {
      if (OrderedSchedule(schedule))
        __kmpc_barrier(loc, threadId);
      schedule = kmp_sched_static_chunk;
      chunk = tripCount; // one thread gets the whole loop
    } else if (schedule == kmp_sched_runtime) {
      // process runtime
      omp_sched_t rtSched;
      int ChunkInt;
      omp_get_schedule(&rtSched, &ChunkInt);
      chunk = ChunkInt;
      switch (rtSched) {
      case omp_sched_static: {
        if (chunk > 0)
          schedule = kmp_sched_static_chunk;
        else
          schedule = kmp_sched_static_nochunk;
        break;
      }
      case omp_sched_auto: {
        schedule = kmp_sched_static_chunk;
        chunk = 1;
        break;
      }
      case omp_sched_dynamic:
      case omp_sched_guided: {
        schedule = kmp_sched_dynamic;
        break;
      }
      }
    } else if (schedule == kmp_sched_auto) {
      schedule = kmp_sched_static_chunk;
      chunk = 1;
    } else {
      // ASSERT(LT_FUSSY,
      //        schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
      //        "unknown schedule %d & chunk %lld\n", (int)schedule,
      //        (long long)chunk);
    }
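    // At this point the incoming schedule has been resolved: serial and
    // ordered loops were folded into a single static chunk, runtime and auto
    // were mapped above, and guided is handled like dynamic below.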

    // init schedules
    if (schedule == kmp_sched_static_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_balanced_chunk) {
      ASSERT0(LT_FUSSY, chunk > 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      // round up to make sure the chunk is enough to cover all iterations
      T span = (tripCount + tnum - 1) / tnum;
      // perform chunk adjustment
      chunk = (span + chunk - 1) & ~(chunk - 1);

      T oldUb = ub;
      ForStaticChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      ASSERT0(LT_FUSSY, ub >= lb, "ub must be >= lb.");
      if (ub > oldUb)
        ub = oldUb;
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_static_nochunk) {
      ASSERT0(LT_FUSSY, chunk == 0, "bad chunk value");
      // save sched state
      DST->ScheduleType = schedule;
      // save ub
      DST->LoopUpperBound = ub;
      // compute static chunk
      ST stride;
      int lastiter = 0;
      ForStaticNoChunk(lastiter, lb, ub, stride, chunk, threadId, tnum);
      // save computed params
      DST->Chunk = chunk;
      DST->NextLowerBound = lb;
      DST->Stride = stride;
    } else if (schedule == kmp_sched_dynamic || schedule == kmp_sched_guided) {
      // save data
      DST->ScheduleType = schedule;
      if (chunk < 1)
        chunk = 1;
      DST->Chunk = chunk;
      DST->LoopUpperBound = ub;
      DST->NextLowerBound = lb;
      __kmpc_barrier(loc, threadId);
      if (tid == 0) {
        Cnt = 0;
        fence::team(atomic::seq_cst);
      }
      __kmpc_barrier(loc, threadId);
    }
  }

  ////////////////////////////////////////////////////////////////////////////////
  // Support for dispatch next

  static uint64_t NextIter() {
    __kmpc_impl_lanemask_t active = mapping::activemask();
    uint32_t leader = utils::ffs(active) - 1;
    uint32_t change = utils::popc(active);
    __kmpc_impl_lanemask_t lane_mask_lt = mapping::lanemaskLT();
    unsigned int rank = utils::popc(active & lane_mask_lt);
    uint64_t warp_res = 0;
    if (rank == 0) {
      warp_res = atomic::add(&Cnt, change, atomic::seq_cst);
    }
    warp_res = utils::shuffle(active, warp_res, leader, mapping::getWarpSize());
    return warp_res + rank;
  }
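
  // NextIter hands out global chunk numbers from the shared counter Cnt with
  // one atomic per warp: the lowest active lane adds the number of active
  // lanes, the old value is shuffled back to every lane, and each lane adds
  // its rank among the active lanes. For example, if lanes 0, 2 and 5 are
  // active and Cnt was 12, they receive 12, 13 and 14 and Cnt becomes 15.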

  static int DynamicNextChunk(T &lb, T &ub, T chunkSize, T loopLowerBound,
                              T loopUpperBound) {
    T N = NextIter();
    lb = loopLowerBound + N * chunkSize;
    ub = lb + chunkSize - 1; // Clang uses i <= ub

    // 3 result cases:
    //  a. lb and ub < loopUpperBound --> NOT_FINISHED
    //  b. lb <= loopUpperBound and ub >= loopUpperBound: last chunk -->
    //  LAST_CHUNK
    //  c. lb and ub > loopUpperBound: empty chunk --> FINISHED
    // a.
    if (lb <= loopUpperBound && ub < loopUpperBound) {
      return NOT_FINISHED;
    }
    // b.
    if (lb <= loopUpperBound) {
      ub = loopUpperBound;
      return LAST_CHUNK;
    }
    // c. if we are here, we are in case 'c'
    lb = loopUpperBound + 2;
    ub = loopUpperBound + 1;
    return FINISHED;
  }
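
  // For example (illustrative values): with chunkSize = 4 over [0,9], the
  // chunk for N = 2 is clipped from [8,11] to [8,9] and reported as
  // LAST_CHUNK, while N = 3 yields [12,15], which lies past the upper bound
  // and is reported as FINISHED.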

  static int dispatch_next(IdentTy *loc, int32_t gtid, int32_t *plast,
                           T *plower, T *pupper, ST *pstride,
                           DynamicScheduleTracker *DST) {
    // ID of a thread in its own warp

    // automatically selects thread or warp ID based on selected implementation
    ASSERT0(LT_FUSSY, gtid < omp_get_num_threads(),
            "current thread is not needed here; error");
    // retrieve schedule
    kmp_sched_t schedule = DST->ScheduleType;

    // xxx reduce to one
    if (schedule == kmp_sched_static_chunk ||
        schedule == kmp_sched_static_nochunk) {
      T myLb = DST->NextLowerBound;
      T ub = DST->LoopUpperBound;
      // finished?
      if (myLb > ub) {
        return DISPATCH_FINISHED;
      }
      // not finished, save current bounds
      ST chunk = DST->Chunk;
      *plower = myLb;
      T myUb = myLb + chunk - 1; // Clang uses i <= ub
      if (myUb > ub)
        myUb = ub;
      *pupper = myUb;
      *plast = (int32_t)(myUb == ub);

      // increment next lower bound by the stride
      ST stride = DST->Stride;
      DST->NextLowerBound = myLb + stride;
      return DISPATCH_NOTFINISHED;
    }
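
    // For example (illustrative values): a static chunk schedule with
    // chunk = 10, stride = 40 and loop upper bound 99 hands thread 1 the
    // chunks [10,19], [50,59] and [90,99] on successive calls (the last one
    // with *plast set), and then returns DISPATCH_FINISHED.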
    ASSERT0(LT_FUSSY,
            schedule == kmp_sched_dynamic || schedule == kmp_sched_guided,
            "bad sched");
    T myLb, myUb;
    int finished = DynamicNextChunk(myLb, myUb, DST->Chunk, DST->NextLowerBound,
                                    DST->LoopUpperBound);

    if (finished == FINISHED)
      return DISPATCH_FINISHED;

    // not finished (either not finished or last chunk)
    *plast = (int32_t)(finished == LAST_CHUNK);
    *plower = myLb;
    *pupper = myUb;
    *pstride = 1;

    return DISPATCH_NOTFINISHED;
  }

  static void dispatch_fini() {
    // nothing
  }

  ////////////////////////////////////////////////////////////////////////////////
  // end of template class that encapsulates all the helper functions
  ////////////////////////////////////////////////////////////////////////////////
};

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (dyn loops)
////////////////////////////////////////////////////////////////////////////////

// TODO: Expand the dispatch API to take a DST pointer which can then be
//       allocated properly without malloc.
// For now, each team will contain an LDS pointer (ThreadDST) to a global array
// of references to the DST structs allocated (in global memory) for each thread
// in the team. The global memory array is allocated during the init phase if it
// was not allocated already and will be deallocated when the dispatch phase
// ends:
//
//  __kmpc_dispatch_init
//
//  ** Dispatch loop **
//
//  __kmpc_dispatch_deinit
//
static DynamicScheduleTracker **SHARED(ThreadDST);
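
// Each DST links to its predecessor via NextDST, so the trackers of a thread
// form a small stack: pushDST on __kmpc_dispatch_init, peekDST on
// __kmpc_dispatch_next, and popDST on __kmpc_dispatch_deinit.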

// Create a new DST, link it to the current one, and make the new one current.
static DynamicScheduleTracker *pushDST() {
  int32_t ThreadIndex = mapping::getThreadIdInBlock();
  // Each block will allocate an array of pointers to DST structs. The array is
  // equal in length to the number of threads in that block.
  if (!ThreadDST) {
    // Allocate global memory array of pointers to DST structs:
    if (mapping::isMainThreadInGenericMode() || ThreadIndex == 0)
      ThreadDST = static_cast<DynamicScheduleTracker **>(
          memory::allocGlobal(mapping::getNumberOfThreadsInBlock() *
                                  sizeof(DynamicScheduleTracker *),
                              "new ThreadDST array"));
    synchronize::threads(atomic::seq_cst);

    // Initialize the array pointers:
    ThreadDST[ThreadIndex] = nullptr;
  }

  // Create a DST struct for the current thread:
  DynamicScheduleTracker *NewDST = static_cast<DynamicScheduleTracker *>(
      memory::allocGlobal(sizeof(DynamicScheduleTracker), "new DST"));
  *NewDST = DynamicScheduleTracker({0});

  // Add the new DST struct to the array of DST structs:
  NewDST->NextDST = ThreadDST[ThreadIndex];
  ThreadDST[ThreadIndex] = NewDST;
  return NewDST;
}

// Return the current DST.
static DynamicScheduleTracker *peekDST() {
  return ThreadDST[mapping::getThreadIdInBlock()];
}

// Pop the current DST and restore the last one.
static void popDST() {
  int32_t ThreadIndex = mapping::getThreadIdInBlock();
  DynamicScheduleTracker *CurrentDST = ThreadDST[ThreadIndex];
  DynamicScheduleTracker *OldDST = CurrentDST->NextDST;
  memory::freeGlobal(CurrentDST, "remove DST");
  ThreadDST[ThreadIndex] = OldDST;

  // Check if we need to deallocate the global array. Ensure all threads
  // in the block have finished deallocating the individual DSTs.
  synchronize::threads(atomic::seq_cst);
  if (!ThreadDST[ThreadIndex] && !ThreadIndex) {
    memory::freeGlobal(ThreadDST, "remove ThreadDST array");
    ThreadDST = nullptr;
  }
  synchronize::threads(atomic::seq_cst);
}

void workshare::init(bool IsSPMD) {
  if (mapping::isInitialThreadInLevel0(IsSPMD))
    ThreadDST = nullptr;
}

extern "C" {

// init
void __kmpc_dispatch_init_4(IdentTy *loc, int32_t tid, int32_t schedule,
                            int32_t lb, int32_t ub, int32_t st, int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_4u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint32_t lb, uint32_t ub, int32_t st,
                             int32_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8(IdentTy *loc, int32_t tid, int32_t schedule,
                            int64_t lb, int64_t ub, int64_t st, int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

void __kmpc_dispatch_init_8u(IdentTy *loc, int32_t tid, int32_t schedule,
                             uint64_t lb, uint64_t ub, int64_t st,
                             int64_t chunk) {
  DynamicScheduleTracker *DST = pushDST();
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_init(
      loc, tid, (kmp_sched_t)schedule, lb, ub, st, chunk, DST);
}

// next
int __kmpc_dispatch_next_4(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int32_t *p_lb, int32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_4u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint32_t *p_lb, uint32_t *p_ub, int32_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8(IdentTy *loc, int32_t tid, int32_t *p_last,
                           int64_t *p_lb, int64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

int __kmpc_dispatch_next_8u(IdentTy *loc, int32_t tid, int32_t *p_last,
                            uint64_t *p_lb, uint64_t *p_ub, int64_t *p_st) {
  DynamicScheduleTracker *DST = peekDST();
  return omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_next(
      loc, tid, p_last, p_lb, p_ub, p_st, DST);
}

// fini
void __kmpc_dispatch_fini_4(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_4u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_8(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::dispatch_fini();
}

void __kmpc_dispatch_fini_8u(IdentTy *loc, int32_t tid) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::dispatch_fini();
}

// deinit
void __kmpc_dispatch_deinit(IdentTy *loc, int32_t tid) { popDST(); }

////////////////////////////////////////////////////////////////////////////////
// KMP interface implementation (static loops)
////////////////////////////////////////////////////////////////////////////////

void __kmpc_for_static_init_4(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int32_t *plower, int32_t *pupper,
                              int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_4u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint32_t *plower, uint32_t *pupper,
                               int32_t *pstride, int32_t incr, int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8(IdentTy *loc, int32_t global_tid,
                              int32_t schedtype, int32_t *plastiter,
                              int64_t *plower, int64_t *pupper,
                              int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_init_8u(IdentTy *loc, int32_t global_tid,
                               int32_t schedtype, int32_t *plastiter,
                               uint64_t *plower, uint64_t *pupper,
                               int64_t *pstride, int64_t incr, int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_4(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int32_t *plower, int32_t *pupper,
                                     int32_t *pstride, int32_t incr,
                                     int32_t chunk) {
  omptarget_nvptx_LoopSupport<int32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_4u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint32_t *plower, uint32_t *pupper,
                                      int32_t *pstride, int32_t incr,
                                      int32_t chunk) {
  omptarget_nvptx_LoopSupport<uint32_t, int32_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8(IdentTy *loc, int32_t global_tid,
                                     int32_t schedtype, int32_t *plastiter,
                                     int64_t *plower, int64_t *pupper,
                                     int64_t *pstride, int64_t incr,
                                     int64_t chunk) {
  omptarget_nvptx_LoopSupport<int64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_distribute_static_init_8u(IdentTy *loc, int32_t global_tid,
                                      int32_t schedtype, int32_t *plastiter,
                                      uint64_t *plower, uint64_t *pupper,
                                      int64_t *pstride, int64_t incr,
                                      int64_t chunk) {
  omptarget_nvptx_LoopSupport<uint64_t, int64_t>::for_static_init(
      global_tid, schedtype, plastiter, plower, pupper, pstride, chunk,
      mapping::isSPMDMode());
}

void __kmpc_for_static_fini(IdentTy *loc, int32_t global_tid) {}

void __kmpc_distribute_static_fini(IdentTy *loc, int32_t global_tid) {}
}

namespace ompx {

/// Helper class to hide the generic loop nest and provide the template argument
/// throughout.
template <typename Ty> class StaticLoopChunker {

  /// Generic loop nest that handles block and/or thread distribution in the
  /// absence of user specified chunk sizes. This implicitly picks a block chunk
  /// size equal to the number of threads in the block and a thread chunk size
  /// equal to one. In contrast to the chunked version we can get away with a
  /// single loop in this case
  static void NormalizedLoopNestNoChunk(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty NumBlocks, Ty BId, Ty NumThreads,
                                        Ty TId, Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * NumThreads;

    // Start index in the normalized space.
    Ty IV = BId * NumThreads + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    if (IV < NumIters) {
      do {

        // Execute the loop body.
        LoopBody(IV, Arg);

        // Every thread executed one block and thread chunk now.
        IV += KernelIteration;

        if (OneIterationPerThread)
          return;

      } while (IV < NumIters);
    }
  }
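
  // For example (illustrative values): with 2 blocks of 4 threads each, the
  // kernel covers 8 iterations per sweep; thread 2 of block 1 starts at
  // IV = 1 * 4 + 2 = 6 and then executes 14, 22, ... while IV < NumIters.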

  /// Generic loop nest that handles block and/or thread distribution in the
  /// presence of user specified chunk sizes (for at least one of them).
  static void NormalizedLoopNestChunked(void (*LoopBody)(Ty, void *), void *Arg,
                                        Ty BlockChunk, Ty NumBlocks, Ty BId,
                                        Ty ThreadChunk, Ty NumThreads, Ty TId,
                                        Ty NumIters,
                                        bool OneIterationPerThread) {
    Ty KernelIteration = NumBlocks * BlockChunk;

    // Start index in the chunked space.
    Ty IV = BId * BlockChunk + TId;
    ASSERT(IV >= 0, "Bad index");

    // Cover the entire iteration space; assumptions in the caller might allow
    // us to simplify this loop to a conditional.
    do {

      Ty BlockChunkLeft =
          BlockChunk >= TId * ThreadChunk ? BlockChunk - TId * ThreadChunk : 0;
      Ty ThreadChunkLeft =
          ThreadChunk <= BlockChunkLeft ? ThreadChunk : BlockChunkLeft;

      while (ThreadChunkLeft--) {

        // Given the blocking it's hard to keep track of what to execute.
        if (IV >= NumIters)
          return;

        // Execute the loop body.
        LoopBody(IV, Arg);

        if (OneIterationPerThread)
          return;

        ++IV;
      }

      IV += KernelIteration;

    } while (IV < NumIters);
  }

public:
  /// Worksharing `for`-loop.
  static void For(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                  Ty NumIters, Ty NumThreads, Ty ThreadChunk) {
    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(ThreadChunk >= 0, "Bad thread count");

    // All threads need to participate but we don't know if we are in a
    // parallel at all or if the user might have used a `num_threads` clause
    // on the parallel and reduced the number compared to the block size.
    // Since nested parallels are possible too we need to get the thread id
    // from the `omp` getter and not the mapping directly.
    Ty TId = omp_get_thread_num();

    // There are no blocks involved here.
    Ty BlockChunk = 0;
    Ty NumBlocks = 1;
    Ty BId = 0;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeThreadsOversubscription()) {
      ASSERT(NumThreads >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);
  }

  /// Worksharing `distribute`-loop.
  static void Distribute(IdentTy *Loc, void (*LoopBody)(Ty, void *), void *Arg,
                         Ty NumIters, Ty BlockChunk) {
    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block count");

    // There are no threads involved here.
    Ty ThreadChunk = 0;
    Ty NumThreads = 1;
    Ty TId = 0;
    ASSERT(TId == mapping::getThreadIdInBlock(), "Bad thread id");

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If we know we have more blocks than iterations we can indicate that to
    // avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription()) {
      ASSERT(NumBlocks >= NumIters, "Broken assumption");
      OneIterationPerThread = true;
    }

    if (BlockChunk != NumThreads)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 0, "Bad distribute");
    ASSERT(icv::ActiveLevel == 0, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
    ASSERT(state::ParallelTeamSize == 1, "Bad distribute");
  }

  /// Worksharing `distribute parallel for`-loop.
  static void DistributeFor(IdentTy *Loc, void (*LoopBody)(Ty, void *),
                            void *Arg, Ty NumIters, Ty NumThreads,
                            Ty BlockChunk, Ty ThreadChunk) {
    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");

    ASSERT(NumIters >= 0, "Bad iteration count");
    ASSERT(BlockChunk >= 0, "Bad block count");
    ASSERT(ThreadChunk >= 0, "Bad thread count");

    // All threads need to participate but the user might have used a
    // `num_threads` clause on the parallel and reduced the number compared to
    // the block size.
    Ty TId = mapping::getThreadIdInBlock();

    // All teams need to participate.
    Ty NumBlocks = mapping::getNumberOfBlocksInKernel();
    Ty BId = mapping::getBlockIdInKernel();

    // If the block chunk is not specified we pick a default now.
    if (BlockChunk == 0)
      BlockChunk = NumThreads;

    // If the thread chunk is not specified we pick a default now.
    if (ThreadChunk == 0)
      ThreadChunk = 1;

    // If we know we have more threads (across all blocks) than iterations we
    // can indicate that to avoid an outer loop.
    bool OneIterationPerThread = false;
    if (config::getAssumeTeamsOversubscription() &
        config::getAssumeThreadsOversubscription()) {
      OneIterationPerThread = true;
      ASSERT(NumBlocks * NumThreads >= NumIters, "Broken assumption");
    }

    if (BlockChunk != NumThreads || ThreadChunk != 1)
      NormalizedLoopNestChunked(LoopBody, Arg, BlockChunk, NumBlocks, BId,
                                ThreadChunk, NumThreads, TId, NumIters,
                                OneIterationPerThread);
    else
      NormalizedLoopNestNoChunk(LoopBody, Arg, NumBlocks, BId, NumThreads, TId,
                                NumIters, OneIterationPerThread);

    ASSERT(icv::Level == 1, "Bad distribute");
    ASSERT(icv::ActiveLevel == 1, "Bad distribute");
    ASSERT(state::ParallelRegionFn == nullptr, "Bad distribute");
  }
};

} // namespace ompx

#define OMP_LOOP_ENTRY(BW, TY)                                                 \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_for_static_loop##BW(                                   \
          IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,       \
          TY num_threads, TY block_chunk, TY thread_chunk) {                   \
    ompx::StaticLoopChunker<TY>::DistributeFor(                                \
        loc, fn, arg, num_iters + 1, num_threads, block_chunk, thread_chunk);  \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void                                  \
      __kmpc_distribute_static_loop##BW(IdentTy *loc, void (*fn)(TY, void *),  \
                                        void *arg, TY num_iters,               \
                                        TY block_chunk) {                      \
    ompx::StaticLoopChunker<TY>::Distribute(loc, fn, arg, num_iters + 1,       \
                                            block_chunk);                      \
  }                                                                            \
  [[gnu::flatten, clang::always_inline]] void __kmpc_for_static_loop##BW(      \
      IdentTy *loc, void (*fn)(TY, void *), void *arg, TY num_iters,           \
      TY num_threads, TY thread_chunk) {                                       \
    ompx::StaticLoopChunker<TY>::For(loc, fn, arg, num_iters + 1, num_threads, \
                                     thread_chunk);                            \
  }

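// Each OMP_LOOP_ENTRY instantiation below defines the three entry points for
// one induction-variable type; e.g. OMP_LOOP_ENTRY(_4, int32_t) emits
// __kmpc_distribute_for_static_loop_4, __kmpc_distribute_static_loop_4 and
// __kmpc_for_static_loop_4 operating on int32_t.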
extern "C" {
OMP_LOOP_ENTRY(_4, int32_t)
OMP_LOOP_ENTRY(_4u, uint32_t)
OMP_LOOP_ENTRY(_8, int64_t)
OMP_LOOP_ENTRY(_8u, uint64_t)
}

#pragma omp end declare target