xref: /openbsd-src/gnu/llvm/libcxx/src/support/win32/thread_win32.cpp (revision 46035553bfdd96e63c94e32da0210227ec2e3cf1)
1 // -*- C++ -*-
2 //===-------------------- support/win32/thread_win32.cpp ------------------===//
3 //
4 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5 // See https://llvm.org/LICENSE.txt for license information.
6 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7 //
8 //===----------------------------------------------------------------------===//
9 
10 #include <__threading_support>
11 #include <windows.h>
12 #include <process.h>
13 #include <fibersapi.h>
14 
15 _LIBCPP_BEGIN_NAMESPACE_STD
16 
17 static_assert(sizeof(__libcpp_mutex_t) == sizeof(SRWLOCK), "");
18 static_assert(alignof(__libcpp_mutex_t) == alignof(SRWLOCK), "");
19 
20 static_assert(sizeof(__libcpp_recursive_mutex_t) == sizeof(CRITICAL_SECTION),
21               "");
22 static_assert(alignof(__libcpp_recursive_mutex_t) == alignof(CRITICAL_SECTION),
23               "");
24 
25 static_assert(sizeof(__libcpp_condvar_t) == sizeof(CONDITION_VARIABLE), "");
26 static_assert(alignof(__libcpp_condvar_t) == alignof(CONDITION_VARIABLE), "");
27 
28 static_assert(sizeof(__libcpp_exec_once_flag) == sizeof(INIT_ONCE), "");
29 static_assert(alignof(__libcpp_exec_once_flag) == alignof(INIT_ONCE), "");
30 
31 static_assert(sizeof(__libcpp_thread_id) == sizeof(DWORD), "");
32 static_assert(alignof(__libcpp_thread_id) == alignof(DWORD), "");
33 
34 static_assert(sizeof(__libcpp_thread_t) == sizeof(HANDLE), "");
35 static_assert(alignof(__libcpp_thread_t) == alignof(HANDLE), "");
36 
37 static_assert(sizeof(__libcpp_tls_key) == sizeof(DWORD), "");
38 static_assert(alignof(__libcpp_tls_key) == alignof(DWORD), "");
39 
40 // Mutex
41 int __libcpp_recursive_mutex_init(__libcpp_recursive_mutex_t *__m)
42 {
43   InitializeCriticalSection((LPCRITICAL_SECTION)__m);
44   return 0;
45 }
46 
47 int __libcpp_recursive_mutex_lock(__libcpp_recursive_mutex_t *__m)
48 {
49   EnterCriticalSection((LPCRITICAL_SECTION)__m);
50   return 0;
51 }
52 
53 bool __libcpp_recursive_mutex_trylock(__libcpp_recursive_mutex_t *__m)
54 {
55   return TryEnterCriticalSection((LPCRITICAL_SECTION)__m) != 0;
56 }
57 
58 int __libcpp_recursive_mutex_unlock(__libcpp_recursive_mutex_t *__m)
59 {
60   LeaveCriticalSection((LPCRITICAL_SECTION)__m);
61   return 0;
62 }
63 
64 int __libcpp_recursive_mutex_destroy(__libcpp_recursive_mutex_t *__m)
65 {
66   DeleteCriticalSection((LPCRITICAL_SECTION)__m);
67   return 0;
68 }
69 
70 int __libcpp_mutex_lock(__libcpp_mutex_t *__m)
71 {
72   AcquireSRWLockExclusive((PSRWLOCK)__m);
73   return 0;
74 }
75 
76 bool __libcpp_mutex_trylock(__libcpp_mutex_t *__m)
77 {
78   return TryAcquireSRWLockExclusive((PSRWLOCK)__m) != 0;
79 }
80 
81 int __libcpp_mutex_unlock(__libcpp_mutex_t *__m)
82 {
83   ReleaseSRWLockExclusive((PSRWLOCK)__m);
84   return 0;
85 }
86 
87 int __libcpp_mutex_destroy(__libcpp_mutex_t *__m)
88 {
89   static_cast<void>(__m);
90   return 0;
91 }
92 
93 // Condition Variable
94 int __libcpp_condvar_signal(__libcpp_condvar_t *__cv)
95 {
96   WakeConditionVariable((PCONDITION_VARIABLE)__cv);
97   return 0;
98 }
99 
100 int __libcpp_condvar_broadcast(__libcpp_condvar_t *__cv)
101 {
102   WakeAllConditionVariable((PCONDITION_VARIABLE)__cv);
103   return 0;
104 }
105 
106 int __libcpp_condvar_wait(__libcpp_condvar_t *__cv, __libcpp_mutex_t *__m)
107 {
108   SleepConditionVariableSRW((PCONDITION_VARIABLE)__cv, (PSRWLOCK)__m, INFINITE, 0);
109   return 0;
110 }
111 
112 int __libcpp_condvar_timedwait(__libcpp_condvar_t *__cv, __libcpp_mutex_t *__m,
113                                __libcpp_timespec_t *__ts)
114 {
115   using namespace _VSTD::chrono;
116 
117   auto duration = seconds(__ts->tv_sec) + nanoseconds(__ts->tv_nsec);
118   auto abstime =
119       system_clock::time_point(duration_cast<system_clock::duration>(duration));
120   auto timeout_ms = duration_cast<milliseconds>(abstime - system_clock::now());
121 
122   if (!SleepConditionVariableSRW((PCONDITION_VARIABLE)__cv, (PSRWLOCK)__m,
123                                  timeout_ms.count() > 0 ? timeout_ms.count()
124                                                         : 0,
125                                  0))
126     {
127       auto __ec = GetLastError();
128       return __ec == ERROR_TIMEOUT ? ETIMEDOUT : __ec;
129     }
130   return 0;
131 }
132 
133 int __libcpp_condvar_destroy(__libcpp_condvar_t *__cv)
134 {
135   static_cast<void>(__cv);
136   return 0;
137 }
138 
139 // Execute Once
140 static inline _LIBCPP_INLINE_VISIBILITY BOOL CALLBACK
141 __libcpp_init_once_execute_once_thunk(PINIT_ONCE __init_once, PVOID __parameter,
142                                       PVOID *__context)
143 {
144   static_cast<void>(__init_once);
145   static_cast<void>(__context);
146 
147   void (*init_routine)(void) = reinterpret_cast<void (*)(void)>(__parameter);
148   init_routine();
149   return TRUE;
150 }
151 
152 int __libcpp_execute_once(__libcpp_exec_once_flag *__flag,
153                           void (*__init_routine)(void))
154 {
155   if (!InitOnceExecuteOnce((PINIT_ONCE)__flag, __libcpp_init_once_execute_once_thunk,
156                            reinterpret_cast<void *>(__init_routine), NULL))
157     return GetLastError();
158   return 0;
159 }
160 
161 // Thread ID
162 bool __libcpp_thread_id_equal(__libcpp_thread_id __lhs,
163                               __libcpp_thread_id __rhs)
164 {
165   return __lhs == __rhs;
166 }
167 
168 bool __libcpp_thread_id_less(__libcpp_thread_id __lhs, __libcpp_thread_id __rhs)
169 {
170   return __lhs < __rhs;
171 }
172 
173 // Thread
// Heap-allocated bundle that __libcpp_thread_create hands to the new thread
// through _beginthreadex's single void* argument. Freed by the thunk.
struct __libcpp_beginthreadex_thunk_data {
  void *(*__func)(void *); // entry point supplied by the caller
  void *__arg;             // argument forwarded to __func
};
179 
180 static inline _LIBCPP_INLINE_VISIBILITY unsigned WINAPI
181 __libcpp_beginthreadex_thunk(void *__raw_data)
182 {
183   auto *__data =
184       static_cast<__libcpp_beginthreadex_thunk_data *>(__raw_data);
185   auto *__func = __data->__func;
186   void *__arg = __data->__arg;
187   delete __data;
188   return static_cast<unsigned>(reinterpret_cast<uintptr_t>(__func(__arg)));
189 }
190 
191 bool __libcpp_thread_isnull(const __libcpp_thread_t *__t) {
192   return *__t == 0;
193 }
194 
195 int __libcpp_thread_create(__libcpp_thread_t *__t, void *(*__func)(void *),
196                            void *__arg)
197 {
198   auto *__data = new __libcpp_beginthreadex_thunk_data;
199   __data->__func = __func;
200   __data->__arg = __arg;
201 
202   *__t = reinterpret_cast<HANDLE>(_beginthreadex(nullptr, 0,
203                                                  __libcpp_beginthreadex_thunk,
204                                                  __data, 0, nullptr));
205 
206   if (*__t)
207     return 0;
208   return GetLastError();
209 }
210 
211 __libcpp_thread_id __libcpp_thread_get_current_id()
212 {
213   return GetCurrentThreadId();
214 }
215 
216 __libcpp_thread_id __libcpp_thread_get_id(const __libcpp_thread_t *__t)
217 {
218   return GetThreadId(*__t);
219 }
220 
221 int __libcpp_thread_join(__libcpp_thread_t *__t)
222 {
223   if (WaitForSingleObjectEx(*__t, INFINITE, FALSE) == WAIT_FAILED)
224     return GetLastError();
225   if (!CloseHandle(*__t))
226     return GetLastError();
227   return 0;
228 }
229 
230 int __libcpp_thread_detach(__libcpp_thread_t *__t)
231 {
232   if (!CloseHandle(*__t))
233     return GetLastError();
234   return 0;
235 }
236 
237 void __libcpp_thread_yield()
238 {
239   SwitchToThread();
240 }
241 
242 void __libcpp_thread_sleep_for(const chrono::nanoseconds& __ns)
243 {
244   using namespace chrono;
245   // round-up to the nearest milisecond
246   milliseconds __ms =
247       duration_cast<milliseconds>(__ns + chrono::nanoseconds(999999));
248   // FIXME(compnerd) this should be an alertable sleep (WFSO or SleepEx)
249   Sleep(__ms.count());
250 }
251 
252 // Thread Local Storage
253 int __libcpp_tls_create(__libcpp_tls_key* __key,
254                         void(_LIBCPP_TLS_DESTRUCTOR_CC* __at_exit)(void*))
255 {
256   DWORD index = FlsAlloc(__at_exit);
257   if (index == FLS_OUT_OF_INDEXES)
258     return GetLastError();
259   *__key = index;
260   return 0;
261 }
262 
263 void *__libcpp_tls_get(__libcpp_tls_key __key)
264 {
265   return FlsGetValue(__key);
266 }
267 
268 int __libcpp_tls_set(__libcpp_tls_key __key, void *__p)
269 {
270   if (!FlsSetValue(__key, __p))
271     return GetLastError();
272   return 0;
273 }
274 
275 _LIBCPP_END_NAMESPACE_STD
276