1 //===-- runtime/CUDA/descriptor.cpp ---------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "flang/Runtime/CUDA/descriptor.h" 10 #include "../terminator.h" 11 #include "flang/Runtime/CUDA/allocator.h" 12 #include "flang/Runtime/CUDA/common.h" 13 #include "flang/Runtime/descriptor.h" 14 15 #include "cuda_runtime.h" 16 17 namespace Fortran::runtime::cuda { 18 extern "C" { 19 RT_EXT_API_GROUP_BEGIN 20 21 Descriptor *RTDEF(CUFAllocDescriptor)( 22 std::size_t sizeInBytes, const char *sourceFile, int sourceLine) { 23 return reinterpret_cast<Descriptor *>(CUFAllocManaged(sizeInBytes)); 24 } 25 26 void RTDEF(CUFFreeDescriptor)( 27 Descriptor *desc, const char *sourceFile, int sourceLine) { 28 CUFFreeManaged(reinterpret_cast<void *>(desc)); 29 } 30 31 void *RTDEF(CUFGetDeviceAddress)( 32 void *hostPtr, const char *sourceFile, int sourceLine) { 33 Terminator terminator{sourceFile, sourceLine}; 34 void *p; 35 CUDA_REPORT_IF_ERROR(cudaGetSymbolAddress((void **)&p, hostPtr)); 36 if (!p) { 37 terminator.Crash("Could not retrieve symbol's address"); 38 } 39 return p; 40 } 41 42 void RTDEF(CUFDescriptorSync)(Descriptor *dst, const Descriptor *src, 43 const char *sourceFile, int sourceLine) { 44 std::size_t count{src->SizeInBytes()}; 45 CUDA_REPORT_IF_ERROR(cudaMemcpy( 46 (void *)dst, (const void *)src, count, cudaMemcpyHostToDevice)); 47 } 48 49 void RTDEF(CUFSyncGlobalDescriptor)( 50 void *hostPtr, const char *sourceFile, int sourceLine) { 51 void *devAddr{RTNAME(CUFGetDeviceAddress)(hostPtr, sourceFile, sourceLine)}; 52 RTNAME(CUFDescriptorSync) 53 ((Descriptor *)devAddr, (Descriptor *)hostPtr, sourceFile, sourceLine); 54 } 55 56 RT_EXT_API_GROUP_END 57 } 58 } // namespace Fortran::runtime::cuda 59