xref: /llvm-project/offload/DeviceRTL/include/Synchronization.h (revision 3274bf6b4282a0dafd4b5a2efa09824e5ca417d0)
1 //===- Synchronization.h - OpenMP synchronization utilities ------- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 //
10 //===----------------------------------------------------------------------===//
11 
12 #ifndef OMPTARGET_DEVICERTL_SYNCHRONIZATION_H
13 #define OMPTARGET_DEVICERTL_SYNCHRONIZATION_H
14 
15 #include "DeviceTypes.h"
16 #include "DeviceUtils.h"
17 
18 #pragma omp begin declare target device_type(nohost)
19 
20 namespace ompx {
21 namespace atomic {
22 
/// Atomic memory orderings, mirroring the compiler's __ATOMIC_* macros so the
/// values can be passed straight to the atomic builtins.
enum OrderingTy {
  relaxed = __ATOMIC_RELAXED,
  acquire = __ATOMIC_ACQUIRE,
  // Historical misspelling of `acquire`; kept so existing callers that use
  // `atomic::aquire` still compile. Prefer `acquire` in new code.
  aquire = __ATOMIC_ACQUIRE,
  release = __ATOMIC_RELEASE,
  acq_rel = __ATOMIC_ACQ_REL,
  seq_cst = __ATOMIC_SEQ_CST,
};
30 
/// Atomic memory scopes, mirroring Clang's __MEMORY_SCOPE_* builtin macros.
/// Narrower scopes may allow the backend to emit cheaper hardware atomics.
enum MemScopeTy {
  system = __MEMORY_SCOPE_SYSTEM,    // All threads, all devices, and the host.
  device = __MEMORY_SCOPE_DEVICE,    // All threads on the current device.
  workgroup = __MEMORY_SCOPE_WRKGRP, // Threads of one workgroup/thread block.
  wavefront = __MEMORY_SCOPE_WVFRNT, // Threads of one wavefront/warp.
  single = __MEMORY_SCOPE_SINGLE,    // The current thread only.
};
38 
/// Atomically increment \p *Addr and wrap at \p V with \p Ordering semantics
/// at \p MemScope scope. NOTE(review): the return value is presumably the
/// value *Addr held before the update (atomicInc-style) -- confirm against
/// the implementation in Synchronization.cpp.
uint32_t inc(uint32_t *Addr, uint32_t V, OrderingTy Ordering,
             MemScopeTy MemScope = MemScopeTy::device);
42 
/// Atomically perform <op> on \p V and \p *Addr with \p Ordering semantics.
/// The result is stored in \p *Addr.
///{
46 
47 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
48 bool cas(Ty *Address, V ExpectedV, V DesiredV, atomic::OrderingTy OrderingSucc,
49          atomic::OrderingTy OrderingFail,
50          MemScopeTy MemScope = MemScopeTy::device) {
51   return __scoped_atomic_compare_exchange(Address, &ExpectedV, &DesiredV, false,
52                                           OrderingSucc, OrderingFail, MemScope);
53 }
54 
55 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
56 V add(Ty *Address, V Val, atomic::OrderingTy Ordering,
57       MemScopeTy MemScope = MemScopeTy::device) {
58   return __scoped_atomic_fetch_add(Address, Val, Ordering, MemScope);
59 }
60 
61 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
62 V load(Ty *Address, atomic::OrderingTy Ordering,
63        MemScopeTy MemScope = MemScopeTy::device) {
64   return __scoped_atomic_load_n(Address, Ordering, MemScope);
65 }
66 
/// Atomically store \p Val to \p *Address with \p Ordering semantics at
/// \p MemScope scope. Returns nothing.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
void store(Ty *Address, V Val, atomic::OrderingTy Ordering,
           MemScopeTy MemScope = MemScopeTy::device) {
  __scoped_atomic_store_n(Address, Val, Ordering, MemScope);
}
72 
73 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
74 V mul(Ty *Address, V Val, atomic::OrderingTy Ordering,
75       MemScopeTy MemScope = MemScopeTy::device) {
76   Ty TypedCurrentVal, TypedResultVal, TypedNewVal;
77   bool Success;
78   do {
79     TypedCurrentVal = atomic::load(Address, Ordering);
80     TypedNewVal = TypedCurrentVal * Val;
81     Success = atomic::cas(Address, TypedCurrentVal, TypedNewVal, Ordering,
82                           atomic::relaxed, MemScope);
83   } while (!Success);
84   return TypedResultVal;
85 }
86 
87 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
88 utils::enable_if_t<!utils::is_floating_point_v<V>, V>
89 max(Ty *Address, V Val, atomic::OrderingTy Ordering,
90     MemScopeTy MemScope = MemScopeTy::device) {
91   return __scoped_atomic_fetch_max(Address, Val, Ordering, MemScope);
92 }
93 
/// Atomic float max implemented on the integer atomics, exploiting that
/// IEEE-754 bit patterns order like sign-magnitude integers: for a
/// non-negative \p Val a signed integer max on the bits matches the float
/// comparison, while for a negative \p Val larger floats have *smaller*
/// unsigned bit patterns, so an unsigned min is used instead.
/// NOTE(review): NaN inputs are not handled specially -- bitwise comparison
/// of NaN patterns does not follow float semantics; confirm callers never
/// pass NaNs.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering,
    MemScopeTy MemScope = MemScopeTy::device) {
  if (Val >= 0)
    // Non-negative: signed compare of the bits equals the float compare.
    return utils::bitCast<float>(max(
        (int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
  // Negative: the larger float has the smaller unsigned bit pattern.
  return utils::bitCast<float>(min(
      (uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
}
104 
/// Atomic double max on the 64-bit integer atomics; same bit-pattern trick as
/// the float overload above (signed max for non-negative \p Val, unsigned min
/// for negative \p Val). NOTE(review): NaNs are not handled specially.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, double>, V>
max(Ty *Address, V Val, atomic::OrderingTy Ordering,
    MemScopeTy MemScope = MemScopeTy::device) {
  if (Val >= 0)
    // Non-negative: signed compare of the bits equals the double compare.
    return utils::bitCast<double>(max(
        (int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
  // Negative: the larger double has the smaller unsigned bit pattern.
  return utils::bitCast<double>(min(
      (uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
}
115 
116 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
117 utils::enable_if_t<!utils::is_floating_point_v<V>, V>
118 min(Ty *Address, V Val, atomic::OrderingTy Ordering,
119     MemScopeTy MemScope = MemScopeTy::device) {
120   return __scoped_atomic_fetch_min(Address, Val, Ordering, MemScope);
121 }
122 
// TODO: Implement this with __scoped_atomic_fetch_min and remove the
// duplication.
/// Atomic float min on the integer atomics, mirroring the max overload above:
/// for a non-negative \p Val a signed integer min on the bits matches the
/// float comparison; for a negative \p Val smaller floats have *larger*
/// unsigned bit patterns, so an unsigned max is used instead.
/// NOTE(review): NaN inputs are not handled specially.
template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
utils::enable_if_t<utils::is_same_v<V, float>, V>
min(Ty *Address, V Val, atomic::OrderingTy Ordering,
    MemScopeTy MemScope = MemScopeTy::device) {
  if (Val >= 0)
    // Non-negative: signed compare of the bits equals the float compare.
    return utils::bitCast<float>(min(
        (int32_t *)Address, utils::bitCast<int32_t>(Val), Ordering, MemScope));
  // Negative: the smaller float has the larger unsigned bit pattern.
  return utils::bitCast<float>(max(
      (uint32_t *)Address, utils::bitCast<uint32_t>(Val), Ordering, MemScope));
}
134 
135 // TODO: Implement this with __atomic_fetch_max and remove the duplication.
136 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
137 utils::enable_if_t<utils::is_same_v<V, double>, V>
138 min(Ty *Address, utils::remove_addrspace_t<Ty> Val, atomic::OrderingTy Ordering,
139     MemScopeTy MemScope = MemScopeTy::device) {
140   if (Val >= 0)
141     return utils::bitCast<double>(min(
142         (int64_t *)Address, utils::bitCast<int64_t>(Val), Ordering, MemScope));
143   return utils::bitCast<double>(max(
144       (uint64_t *)Address, utils::bitCast<uint64_t>(Val), Ordering, MemScope));
145 }
146 
147 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
148 V bit_or(Ty *Address, V Val, atomic::OrderingTy Ordering,
149          MemScopeTy MemScope = MemScopeTy::device) {
150   return __scoped_atomic_fetch_or(Address, Val, Ordering, MemScope);
151 }
152 
153 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
154 V bit_and(Ty *Address, V Val, atomic::OrderingTy Ordering,
155           MemScopeTy MemScope = MemScopeTy::device) {
156   return __scoped_atomic_fetch_and(Address, Val, Ordering, MemScope);
157 }
158 
159 template <typename Ty, typename V = utils::remove_addrspace_t<Ty>>
160 V bit_xor(Ty *Address, V Val, atomic::OrderingTy Ordering,
161           MemScopeTy MemScope = MemScopeTy::device) {
162   return __scoped_atomic_fetch_xor(Address, Val, Ordering, MemScope);
163 }
164 
165 static inline uint32_t
166 atomicExchange(uint32_t *Address, uint32_t Val, atomic::OrderingTy Ordering,
167                MemScopeTy MemScope = MemScopeTy::device) {
168   uint32_t R;
169   __scoped_atomic_exchange(Address, &Val, &R, Ordering, MemScope);
170   return R;
171 }
172 
173 ///}
174 
175 } // namespace atomic
176 
namespace synchronize {

/// Initialize the synchronization machinery. Must be called by all threads.
/// NOTE(review): \p IsSPMD presumably selects SPMD- vs generic-mode state --
/// confirm against the implementation.
void init(bool IsSPMD);

/// Synchronize all threads in a warp identified by \p Mask.
void warp(LaneMaskTy Mask);

/// Synchronize all threads in a block and perform a fence before and after the
/// barrier according to \p Ordering. Note that the fence might be part of the
/// barrier.
void threads(atomic::OrderingTy Ordering);

/// Synchronizing threads is allowed even if they all hit different instances
/// of `synchronize::threads()`. However, `synchronize::threadsAligned()` is
/// more restrictive in that it requires all threads to hit the same instance.
/// The noinline is removed by the openmp-opt pass and helps to preserve the
/// information till then.
///{

/// Synchronize all threads in a block, which must all be reaching the same
/// instruction (hence all threads in the block are "aligned"). Also perform a
/// fence before and after the barrier according to \p Ordering. Note that the
/// fence might be part of the barrier if the target offers this.
[[gnu::noinline, omp::assume("ompx_aligned_barrier")]] void
threadsAligned(atomic::OrderingTy Ordering);

///}

} // namespace synchronize
207 
namespace fence {

/// Memory fence with \p Ordering semantics for the team
/// (presumably the thread block -- confirm implementation).
void team(atomic::OrderingTy Ordering);

/// Memory fence with \p Ordering semantics for the contention group
/// (presumably all threads of the kernel on the device).
void kernel(atomic::OrderingTy Ordering);

/// Memory fence with \p Ordering semantics for the system
/// (presumably device plus host).
void system(atomic::OrderingTy Ordering);

} // namespace fence
220 
221 } // namespace ompx
222 
223 #pragma omp end declare target
224 
225 #endif
226