xref: /llvm-project/flang/runtime/CUDA/allocatable.cpp (revision 4cb2a519db10f54815c8a4ccd5accbedc1cdfd07)
1 //===-- runtime/CUDA/allocatable.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "flang/Runtime/CUDA/allocatable.h"
10 #include "../assign-impl.h"
11 #include "../stat.h"
12 #include "../terminator.h"
13 #include "flang/Runtime/CUDA/common.h"
14 #include "flang/Runtime/CUDA/descriptor.h"
15 #include "flang/Runtime/CUDA/memmove-function.h"
16 #include "flang/Runtime/allocatable.h"
17 
18 #include "cuda_runtime.h"
19 
20 namespace Fortran::runtime::cuda {
21 
22 extern "C" {
23 RT_EXT_API_GROUP_BEGIN
24 
25 int RTDEF(CUFAllocatableAllocateSync)(Descriptor &desc, int64_t stream,
26     bool hasStat, const Descriptor *errMsg, const char *sourceFile,
27     int sourceLine) {
28   int stat{RTNAME(CUFAllocatableAllocate)(
29       desc, stream, hasStat, errMsg, sourceFile, sourceLine)};
30 #ifndef RT_DEVICE_COMPILATION
31   // Descriptor synchronization is only done when the allocation is done
32   // from the host.
33   if (stat == StatOk) {
34     void *deviceAddr{
35         RTNAME(CUFGetDeviceAddress)((void *)&desc, sourceFile, sourceLine)};
36     RTNAME(CUFDescriptorSync)
37     ((Descriptor *)deviceAddr, &desc, sourceFile, sourceLine);
38   }
39 #endif
40   return stat;
41 }
42 
43 int RTDEF(CUFAllocatableAllocate)(Descriptor &desc, int64_t stream,
44     bool hasStat, const Descriptor *errMsg, const char *sourceFile,
45     int sourceLine) {
46   if (desc.HasAddendum()) {
47     Terminator terminator{sourceFile, sourceLine};
48     // TODO: This require a bit more work to set the correct type descriptor
49     // address
50     terminator.Crash(
51         "not yet implemented: CUDA descriptor allocation with addendum");
52   }
53   // Perform the standard allocation.
54   int stat{RTNAME(AllocatableAllocate)(
55       desc, hasStat, errMsg, sourceFile, sourceLine)};
56   return stat;
57 }
58 
59 int RTDEF(CUFAllocatableAllocateSource)(Descriptor &alloc,
60     const Descriptor &source, int64_t stream, bool hasStat,
61     const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
62   int stat{RTNAME(CUFAllocatableAllocate)(
63       alloc, stream, hasStat, errMsg, sourceFile, sourceLine)};
64   if (stat == StatOk) {
65     Terminator terminator{sourceFile, sourceLine};
66     Fortran::runtime::DoFromSourceAssign(
67         alloc, source, terminator, &MemmoveHostToDevice);
68   }
69   return stat;
70 }
71 
72 int RTDEF(CUFAllocatableAllocateSourceSync)(Descriptor &alloc,
73     const Descriptor &source, int64_t stream, bool hasStat,
74     const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
75   int stat{RTNAME(CUFAllocatableAllocateSync)(
76       alloc, stream, hasStat, errMsg, sourceFile, sourceLine)};
77   if (stat == StatOk) {
78     Terminator terminator{sourceFile, sourceLine};
79     Fortran::runtime::DoFromSourceAssign(
80         alloc, source, terminator, &MemmoveHostToDevice);
81   }
82   return stat;
83 }
84 
85 int RTDEF(CUFAllocatableDeallocate)(Descriptor &desc, bool hasStat,
86     const Descriptor *errMsg, const char *sourceFile, int sourceLine) {
87   // Perform the standard allocation.
88   int stat{RTNAME(AllocatableDeallocate)(
89       desc, hasStat, errMsg, sourceFile, sourceLine)};
90 #ifndef RT_DEVICE_COMPILATION
91   // Descriptor synchronization is only done when the deallocation is done
92   // from the host.
93   if (stat == StatOk) {
94     void *deviceAddr{
95         RTNAME(CUFGetDeviceAddress)((void *)&desc, sourceFile, sourceLine)};
96     RTNAME(CUFDescriptorSync)
97     ((Descriptor *)deviceAddr, &desc, sourceFile, sourceLine);
98   }
99 #endif
100   return stat;
101 }
102 
103 RT_EXT_API_GROUP_END
104 
105 } // extern "C"
106 
107 } // namespace Fortran::runtime::cuda
108