xref: /llvm-project/offload/src/KernelLanguage/API.cpp (revision 80525dfcde5bf8aae6ab6b0810124ba502de6096)
1*80525dfcSJohannes Doerfert //===------ API.cpp - Kernel Language (CUDA/HIP) entry points ----- C++ -*-===//
2*80525dfcSJohannes Doerfert //
3*80525dfcSJohannes Doerfert // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4*80525dfcSJohannes Doerfert // See https://llvm.org/LICENSE.txt for license information.
5*80525dfcSJohannes Doerfert // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6*80525dfcSJohannes Doerfert //
7*80525dfcSJohannes Doerfert //===----------------------------------------------------------------------===//
8*80525dfcSJohannes Doerfert //
9*80525dfcSJohannes Doerfert //===----------------------------------------------------------------------===//
10*80525dfcSJohannes Doerfert 
11*80525dfcSJohannes Doerfert #include "Shared/APITypes.h"
12*80525dfcSJohannes Doerfert 
13*80525dfcSJohannes Doerfert #include <cstdio>
14*80525dfcSJohannes Doerfert 
/// Three-component extent (x, y, z) used for grid and block sizes. Every
/// component is zero unless explicitly set.
struct dim3 {
  unsigned x = 0;
  unsigned y = 0;
  unsigned z = 0;
};
18*80525dfcSJohannes Doerfert 
19*80525dfcSJohannes Doerfert struct __omp_kernel_t {
20*80525dfcSJohannes Doerfert   dim3 __grid_size;
21*80525dfcSJohannes Doerfert   dim3 __block_size;
22*80525dfcSJohannes Doerfert   size_t __shared_memory;
23*80525dfcSJohannes Doerfert 
24*80525dfcSJohannes Doerfert   void *__stream;
25*80525dfcSJohannes Doerfert };
26*80525dfcSJohannes Doerfert 
27*80525dfcSJohannes Doerfert static __omp_kernel_t __current_kernel = {};
28*80525dfcSJohannes Doerfert #pragma omp threadprivate(__current_kernel);
29*80525dfcSJohannes Doerfert 
30*80525dfcSJohannes Doerfert extern "C" {
31*80525dfcSJohannes Doerfert 
32*80525dfcSJohannes Doerfert // TODO: There is little reason we need to keep these names or the way calls are
33*80525dfcSJohannes Doerfert // issued. For now we do to avoid modifying Clang's CUDA codegen. Unclear when
34*80525dfcSJohannes Doerfert // we actually need to push/pop configurations.
35*80525dfcSJohannes Doerfert unsigned __llvmPushCallConfiguration(dim3 __grid_size, dim3 __block_size,
36*80525dfcSJohannes Doerfert                                      size_t __shared_memory, void *__stream) {
37*80525dfcSJohannes Doerfert   __omp_kernel_t &__kernel = __current_kernel;
38*80525dfcSJohannes Doerfert   __kernel.__grid_size = __grid_size;
39*80525dfcSJohannes Doerfert   __kernel.__block_size = __block_size;
40*80525dfcSJohannes Doerfert   __kernel.__shared_memory = __shared_memory;
41*80525dfcSJohannes Doerfert   __kernel.__stream = __stream;
42*80525dfcSJohannes Doerfert   return 0;
43*80525dfcSJohannes Doerfert }
44*80525dfcSJohannes Doerfert 
45*80525dfcSJohannes Doerfert unsigned __llvmPopCallConfiguration(dim3 *__grid_size, dim3 *__block_size,
46*80525dfcSJohannes Doerfert                                     size_t *__shared_memory, void *__stream) {
47*80525dfcSJohannes Doerfert   __omp_kernel_t &__kernel = __current_kernel;
48*80525dfcSJohannes Doerfert   *__grid_size = __kernel.__grid_size;
49*80525dfcSJohannes Doerfert   *__block_size = __kernel.__block_size;
50*80525dfcSJohannes Doerfert   *__shared_memory = __kernel.__shared_memory;
51*80525dfcSJohannes Doerfert   *((void **)__stream) = __kernel.__stream;
52*80525dfcSJohannes Doerfert   return 0;
53*80525dfcSJohannes Doerfert }
54*80525dfcSJohannes Doerfert 
55*80525dfcSJohannes Doerfert int __tgt_target_kernel(void *Loc, int64_t DeviceId, int32_t NumTeams,
56*80525dfcSJohannes Doerfert                         int32_t ThreadLimit, const void *HostPtr,
57*80525dfcSJohannes Doerfert                         KernelArgsTy *Args);
58*80525dfcSJohannes Doerfert 
59*80525dfcSJohannes Doerfert unsigned llvmLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
60*80525dfcSJohannes Doerfert                           void *args, size_t sharedMem, void *stream) {
61*80525dfcSJohannes Doerfert   KernelArgsTy Args = {};
62*80525dfcSJohannes Doerfert   Args.DynCGroupMem = sharedMem;
63*80525dfcSJohannes Doerfert   Args.NumTeams[0] = gridDim.x;
64*80525dfcSJohannes Doerfert   Args.NumTeams[1] = gridDim.y;
65*80525dfcSJohannes Doerfert   Args.NumTeams[2] = gridDim.z;
66*80525dfcSJohannes Doerfert   Args.ThreadLimit[0] = blockDim.x;
67*80525dfcSJohannes Doerfert   Args.ThreadLimit[1] = blockDim.y;
68*80525dfcSJohannes Doerfert   Args.ThreadLimit[2] = blockDim.z;
69*80525dfcSJohannes Doerfert   Args.ArgPtrs = reinterpret_cast<void **>(args);
70*80525dfcSJohannes Doerfert   Args.Flags.IsCUDA = true;
71*80525dfcSJohannes Doerfert   return __tgt_target_kernel(nullptr, 0, gridDim.x, blockDim.x, func, &Args);
72*80525dfcSJohannes Doerfert }
73*80525dfcSJohannes Doerfert }
74