xref: /llvm-project/flang/runtime/CUDA/descriptor.cpp (revision 6dcd2b035da34fa53693b401139a419adb7342db)
1 //===-- runtime/CUDA/descriptor.cpp ---------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Runtime/CUDA/descriptor.h"
10 #include "../terminator.h"
11 #include "flang/Runtime/CUDA/allocator.h"
12 #include "flang/Runtime/CUDA/common.h"
13 #include "flang/Runtime/descriptor.h"
14 
15 #include "cuda_runtime.h"
16 
17 namespace Fortran::runtime::cuda {
18 extern "C" {
19 RT_EXT_API_GROUP_BEGIN
20 
21 Descriptor *RTDEF(CUFAllocDescriptor)(
22     std::size_t sizeInBytes, const char *sourceFile, int sourceLine) {
23   return reinterpret_cast<Descriptor *>(CUFAllocManaged(sizeInBytes));
24 }
25 
26 void RTDEF(CUFFreeDescriptor)(
27     Descriptor *desc, const char *sourceFile, int sourceLine) {
28   CUFFreeManaged(reinterpret_cast<void *>(desc));
29 }
30 
31 void *RTDEF(CUFGetDeviceAddress)(
32     void *hostPtr, const char *sourceFile, int sourceLine) {
33   Terminator terminator{sourceFile, sourceLine};
34   void *p;
35   CUDA_REPORT_IF_ERROR(cudaGetSymbolAddress((void **)&p, hostPtr));
36   if (!p) {
37     terminator.Crash("Could not retrieve symbol's address");
38   }
39   return p;
40 }
41 
42 void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src,
43     const char *sourceFile, int sourceLine) {
44   std::size_t count{src->SizeInBytes()};
45   CUDA_REPORT_IF_ERROR(cudaMemcpy(
46       (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice));
47 }
48 
49 void RTDEF(CUFSyncGlobalDescriptor)(
50     void *hostPtr, const char *sourceFile, int sourceLine) {
51   void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)};
52   RTNAME(CUFDescriptorSync)
53   ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine);
54 }
55 
56 RT_EXT_API_GROUP_END
57 }
58 } // namespace Fortran::runtime::cuda
59