1 //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements basic Async runtime API for supporting Async dialect
10 // to LLVM dialect lowering.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/ExecutionEngine/AsyncRuntime.h"
15
16 #include <atomic>
17 #include <cassert>
18 #include <condition_variable>
19 #include <functional>
20 #include <iostream>
21 #include <mutex>
22 #include <thread>
23 #include <vector>
24
25 #include "llvm/ADT/StringMap.h"
26 #include "llvm/Support/ThreadPool.h"
27
28 using namespace mlir::runtime;
29
30 //===----------------------------------------------------------------------===//
31 // Async runtime API.
32 //===----------------------------------------------------------------------===//
33
34 namespace mlir {
35 namespace runtime {
36 namespace {
37
38 // Forward declare class defined below.
39 class RefCounted;
40
41 // -------------------------------------------------------------------------- //
42 // AsyncRuntime orchestrates all async operations and Async runtime API is built
43 // on top of the default runtime instance.
44 // -------------------------------------------------------------------------- //
45
46 class AsyncRuntime {
47 public:
AsyncRuntime()48 AsyncRuntime() : numRefCountedObjects(0) {}
49
~AsyncRuntime()50 ~AsyncRuntime() {
51 threadPool.wait(); // wait for the completion of all async tasks
52 assert(getNumRefCountedObjects() == 0 &&
53 "all ref counted objects must be destroyed");
54 }
55
getNumRefCountedObjects()56 int64_t getNumRefCountedObjects() {
57 return numRefCountedObjects.load(std::memory_order_relaxed);
58 }
59
getThreadPool()60 llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }
61
62 private:
63 friend class RefCounted;
64
65 // Count the total number of reference counted objects in this instance
66 // of an AsyncRuntime. For debugging purposes only.
addNumRefCountedObjects()67 void addNumRefCountedObjects() {
68 numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
69 }
dropNumRefCountedObjects()70 void dropNumRefCountedObjects() {
71 numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
72 }
73
74 std::atomic<int64_t> numRefCountedObjects;
75 llvm::DefaultThreadPool threadPool;
76 };
77
78 // -------------------------------------------------------------------------- //
79 // A state of the async runtime value (token, value or group).
80 // -------------------------------------------------------------------------- //
81
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name for the current state (debugging only).
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // All enumerators are handled above; this fallback keeps the function
    // well-defined (and silences -Wreturn-type) if `state` ever holds a value
    // outside the enumeration.
    return "unknown";
  }

private:
  StateEnum state;
};
117
118 // -------------------------------------------------------------------------- //
119 // A base class for all reference counted objects created by the async runtime.
120 // -------------------------------------------------------------------------- //
121
122 class RefCounted {
123 public:
RefCounted(AsyncRuntime * runtime,int64_t refCount=1)124 RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125 : runtime(runtime), refCount(refCount) {
126 runtime->addNumRefCountedObjects();
127 }
128
~RefCounted()129 virtual ~RefCounted() {
130 assert(refCount.load() == 0 && "reference count must be zero");
131 runtime->dropNumRefCountedObjects();
132 }
133
134 RefCounted(const RefCounted &) = delete;
135 RefCounted &operator=(const RefCounted &) = delete;
136
addRef(int64_t count=1)137 void addRef(int64_t count = 1) { refCount.fetch_add(count); }
138
dropRef(int64_t count=1)139 void dropRef(int64_t count = 1) {
140 int64_t previous = refCount.fetch_sub(count);
141 assert(previous >= count && "reference count should not go below zero");
142 if (previous == count)
143 destroy();
144 }
145
146 protected:
destroy()147 virtual void destroy() { delete this; }
148
149 private:
150 AsyncRuntime *runtime;
151 std::atomic<int64_t> refCount;
152 };
153
154 } // namespace
155
156 // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()157 static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
158 static auto runtime = std::make_unique<AsyncRuntime>();
159 return runtime;
160 }
161
resetDefaultAsyncRuntime()162 static void resetDefaultAsyncRuntime() {
163 return getDefaultAsyncRuntimeInstance().reset();
164 }
165
getDefaultAsyncRuntime()166 static AsyncRuntime *getDefaultAsyncRuntime() {
167 return getDefaultAsyncRuntimeInstance().get();
168 }
169
170 // Async token provides a mechanism to signal asynchronous operation completion.
171 struct AsyncToken : public RefCounted {
172 // AsyncToken created with a reference count of 2 because it will be returned
173 // to the `async.execute` caller and also will be later on emplaced by the
174 // asynchronously executed task. If the caller immediately will drop its
175 // reference we must ensure that the token will be alive until the
176 // asynchronous operation is completed.
AsyncTokenmlir::runtime::AsyncToken177 AsyncToken(AsyncRuntime *runtime)
178 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
179
180 std::atomic<State::StateEnum> state;
181
182 // Pending awaiters are guarded by a mutex.
183 std::mutex mu;
184 std::condition_variable cv;
185 std::vector<std::function<void()>> awaiters;
186 };
187
188 // Async value provides a mechanism to access the result of asynchronous
189 // operations. It owns the storage that is used to store/load the value of the
190 // underlying type, and a flag to signal if the value is ready or not.
191 struct AsyncValue : public RefCounted {
192 // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValuemlir::runtime::AsyncValue193 AsyncValue(AsyncRuntime *runtime, int64_t size)
194 : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
195 storage(size) {}
196
197 std::atomic<State::StateEnum> state;
198
199 // Use vector of bytes to store async value payload.
200 std::vector<std::byte> storage;
201
202 // Pending awaiters are guarded by a mutex.
203 std::mutex mu;
204 std::condition_variable cv;
205 std::vector<std::function<void()>> awaiters;
206 };
207
208 // Async group provides a mechanism to group together multiple async tokens or
209 // values to await on all of them together (wait for the completion of all
210 // tokens or values added to the group).
211 struct AsyncGroup : public RefCounted {
AsyncGroupmlir::runtime::AsyncGroup212 AsyncGroup(AsyncRuntime *runtime, int64_t size)
213 : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
214
215 std::atomic<int> pendingTokens;
216 std::atomic<int> numErrors;
217 std::atomic<int> rank;
218
219 // Pending awaiters are guarded by a mutex.
220 std::mutex mu;
221 std::condition_variable cv;
222 std::vector<std::function<void()>> awaiters;
223 };
224
225 // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)226 extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
227 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
228 refCounted->addRef(count);
229 }
230
231 // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)232 extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
233 RefCounted *refCounted = static_cast<RefCounted *>(ptr);
234 refCounted->dropRef(count);
235 }
236
237 // Creates a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()238 extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
239 AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
240 return token;
241 }
242
243 // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)244 extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
245 AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
246 return value;
247 }
248
249 // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)250 extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
251 AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
252 return group;
253 }
254
// Adds a token to the group and returns the token's rank (its position in the
// group). The group's pending-token count is decremented either immediately,
// when the token is already completed, or later from an awaiter installed on
// the token. Both the token's and the group's mutexes are held while deciding
// which path to take, so the token cannot transition underneath us.
extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);

  // Invoked (with group->mu held) once the token reaches a terminal state.
  auto onTokenReady = [group, token]() {
    // Increment the number of errors in the group.
    if (State(token->state).isError())
      group->numErrors.fetch_add(1);

    // If pending tokens go below zero it means that more tokens than the group
    // size were added to this group.
    assert(group->pendingTokens > 0 && "wrong group size");

    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when token will become ready. Because this
    // will happen asynchronously we must ensure that `group` is alive until
    // then, and re-acquire the lock.
    group->addRef();

    token->awaiters.emplace_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}
302
303 // Switches `async.token` to available or error state (terminatl state) and runs
304 // all awaiters.
setTokenState(AsyncToken * token,State state)305 static void setTokenState(AsyncToken *token, State state) {
306 assert(state.isAvailableOrError() && "must be terminal state");
307 assert(State(token->state).isUnavailable() && "token must be unavailable");
308
309 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
310 {
311 std::unique_lock<std::mutex> lock(token->mu);
312 token->state = state;
313 token->cv.notify_all();
314 for (auto &awaiter : token->awaiters)
315 awaiter();
316 }
317
318 // Async tokens created with a ref count `2` to keep token alive until the
319 // async task completes. Drop this reference explicitly when token emplaced.
320 token->dropRef();
321 }
322
setValueState(AsyncValue * value,State state)323 static void setValueState(AsyncValue *value, State state) {
324 assert(state.isAvailableOrError() && "must be terminal state");
325 assert(State(value->state).isUnavailable() && "value must be unavailable");
326
327 // Make sure that `dropRef` does not destroy the mutex owned by the lock.
328 {
329 std::unique_lock<std::mutex> lock(value->mu);
330 value->state = state;
331 value->cv.notify_all();
332 for (auto &awaiter : value->awaiters)
333 awaiter();
334 }
335
336 // Async values created with a ref count `2` to keep value alive until the
337 // async task completes. Drop this reference explicitly when value emplaced.
338 value->dropRef();
339 }
340
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)341 extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
342 setTokenState(token, State::kAvailable);
343 }
344
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)345 extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
346 setValueState(value, State::kAvailable);
347 }
348
mlirAsyncRuntimeSetTokenError(AsyncToken * token)349 extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
350 setTokenState(token, State::kError);
351 }
352
mlirAsyncRuntimeSetValueError(AsyncValue * value)353 extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
354 setValueState(value, State::kError);
355 }
356
mlirAsyncRuntimeIsTokenError(AsyncToken * token)357 extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
358 return State(token->state).isError();
359 }
360
mlirAsyncRuntimeIsValueError(AsyncValue * value)361 extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
362 return State(value->state).isError();
363 }
364
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)365 extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
366 return group->numErrors.load() > 0;
367 }
368
mlirAsyncRuntimeAwaitToken(AsyncToken * token)369 extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
370 std::unique_lock<std::mutex> lock(token->mu);
371 if (!State(token->state).isAvailableOrError())
372 token->cv.wait(
373 lock, [token] { return State(token->state).isAvailableOrError(); });
374 }
375
mlirAsyncRuntimeAwaitValue(AsyncValue * value)376 extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
377 std::unique_lock<std::mutex> lock(value->mu);
378 if (!State(value->state).isAvailableOrError())
379 value->cv.wait(
380 lock, [value] { return State(value->state).isAvailableOrError(); });
381 }
382
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)383 extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
384 std::unique_lock<std::mutex> lock(group->mu);
385 if (group->pendingTokens != 0)
386 group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
387 }
388
389 // Returns a pointer to the storage owned by the async value.
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)390 extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
391 assert(!State(value->state).isError() && "unexpected error state");
392 return value->storage.data();
393 }
394
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)395 extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
396 auto *runtime = getDefaultAsyncRuntime();
397 runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
398 }
399
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)400 extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
401 CoroHandle handle,
402 CoroResume resume) {
403 auto execute = [handle, resume]() { (*resume)(handle); };
404 std::unique_lock<std::mutex> lock(token->mu);
405 if (State(token->state).isAvailableOrError()) {
406 lock.unlock();
407 execute();
408 } else {
409 token->awaiters.emplace_back([execute]() { execute(); });
410 }
411 }
412
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)413 extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
414 CoroHandle handle,
415 CoroResume resume) {
416 auto execute = [handle, resume]() { (*resume)(handle); };
417 std::unique_lock<std::mutex> lock(value->mu);
418 if (State(value->state).isAvailableOrError()) {
419 lock.unlock();
420 execute();
421 } else {
422 value->awaiters.emplace_back([execute]() { execute(); });
423 }
424 }
425
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)426 extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
427 CoroHandle handle,
428 CoroResume resume) {
429 auto execute = [handle, resume]() { (*resume)(handle); };
430 std::unique_lock<std::mutex> lock(group->mu);
431 if (group->pendingTokens == 0) {
432 lock.unlock();
433 execute();
434 } else {
435 group->awaiters.emplace_back([execute]() { execute(); });
436 }
437 }
438
mlirAsyncRuntimGetNumWorkerThreads()439 extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
440 return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
441 }
442
443 //===----------------------------------------------------------------------===//
444 // Small async runtime support library for testing.
445 //===----------------------------------------------------------------------===//
446
// Prints the calling thread's id to stdout (testing helper).
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // Cached per thread; get_id() is queried only once per thread.
  static thread_local std::thread::id tid = std::this_thread::get_id();
  std::cout << "Current thread id: " << tid << '\n';
}
451
452 //===----------------------------------------------------------------------===//
453 // MLIR ExecutionEngine dynamic library integration.
454 //===----------------------------------------------------------------------===//
455
456 // Visual Studio had a bug that fails to compile nested generic lambdas
457 // inside an `extern "C"` function.
458 // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
459 // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
460 // a work around for older versions of Visual Studio.
461 // NOLINTNEXTLINE(*-identifier-naming): externally called.
462 extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
463 __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
464
465 // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_init(llvm::StringMap<void * > & exportSymbols)466 void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
467 auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
468 assert(exportSymbols.count(name) == 0 && "symbol already exists");
469 exportSymbols[name] = reinterpret_cast<void *>(ptr);
470 };
471
472 exportSymbol("mlirAsyncRuntimeAddRef",
473 &mlir::runtime::mlirAsyncRuntimeAddRef);
474 exportSymbol("mlirAsyncRuntimeDropRef",
475 &mlir::runtime::mlirAsyncRuntimeDropRef);
476 exportSymbol("mlirAsyncRuntimeExecute",
477 &mlir::runtime::mlirAsyncRuntimeExecute);
478 exportSymbol("mlirAsyncRuntimeGetValueStorage",
479 &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
480 exportSymbol("mlirAsyncRuntimeCreateToken",
481 &mlir::runtime::mlirAsyncRuntimeCreateToken);
482 exportSymbol("mlirAsyncRuntimeCreateValue",
483 &mlir::runtime::mlirAsyncRuntimeCreateValue);
484 exportSymbol("mlirAsyncRuntimeEmplaceToken",
485 &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
486 exportSymbol("mlirAsyncRuntimeEmplaceValue",
487 &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
488 exportSymbol("mlirAsyncRuntimeSetTokenError",
489 &mlir::runtime::mlirAsyncRuntimeSetTokenError);
490 exportSymbol("mlirAsyncRuntimeSetValueError",
491 &mlir::runtime::mlirAsyncRuntimeSetValueError);
492 exportSymbol("mlirAsyncRuntimeIsTokenError",
493 &mlir::runtime::mlirAsyncRuntimeIsTokenError);
494 exportSymbol("mlirAsyncRuntimeIsValueError",
495 &mlir::runtime::mlirAsyncRuntimeIsValueError);
496 exportSymbol("mlirAsyncRuntimeIsGroupError",
497 &mlir::runtime::mlirAsyncRuntimeIsGroupError);
498 exportSymbol("mlirAsyncRuntimeAwaitToken",
499 &mlir::runtime::mlirAsyncRuntimeAwaitToken);
500 exportSymbol("mlirAsyncRuntimeAwaitValue",
501 &mlir::runtime::mlirAsyncRuntimeAwaitValue);
502 exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
503 &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
504 exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
505 &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
506 exportSymbol("mlirAsyncRuntimeCreateGroup",
507 &mlir::runtime::mlirAsyncRuntimeCreateGroup);
508 exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
509 &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
510 exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
511 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
512 exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
513 &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
514 exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
515 &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
516 exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
517 &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
518 }
519
520 // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_destroy()521 extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
522 resetDefaultAsyncRuntime();
523 }
524
525 } // namespace runtime
526 } // namespace mlir
527