//===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the basic Async runtime API that supports lowering the
// Async dialect to the LLVM dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/ExecutionEngine/AsyncRuntime.h"

#include <atomic>
#include <cassert>
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

#include "llvm/ADT/StringMap.h"
#include "llvm/Support/ThreadPool.h"

using namespace mlir::runtime;

//===----------------------------------------------------------------------===//
// Async runtime API.
//===----------------------------------------------------------------------===//

namespace mlir {
namespace runtime {
namespace {

// Forward declare class defined below.
class RefCounted;

// -------------------------------------------------------------------------- //
// AsyncRuntime orchestrates all async operations; the Async runtime API is
// built on top of the default runtime instance.
// -------------------------------------------------------------------------- //

class AsyncRuntime {
public:
  AsyncRuntime() : numRefCountedObjects(0) {}

  ~AsyncRuntime() {
    threadPool.wait(); // wait for the completion of all async tasks
    assert(getNumRefCountedObjects() == 0 &&
           "all ref counted objects must be destroyed");
  }

  int64_t getNumRefCountedObjects() {
    return numRefCountedObjects.load(std::memory_order_relaxed);
  }

  llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }

private:
  friend class RefCounted;

  // Count the total number of reference counted objects in this instance
  // of an AsyncRuntime. For debugging purposes only.
  void addNumRefCountedObjects() {
    numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
  }
  void dropNumRefCountedObjects() {
    numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
  }

  std::atomic<int64_t> numRefCountedObjects;
  llvm::DefaultThreadPool threadPool;
};

// -------------------------------------------------------------------------- //
// The state of an async runtime value (token, value, or group).
// -------------------------------------------------------------------------- //

class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  /* implicit */ operator StateEnum() { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
  }

private:
  StateEnum state;
};

// -------------------------------------------------------------------------- //
// A base class for all reference counted objects created by the async runtime.
// -------------------------------------------------------------------------- //

class RefCounted {
public:
  RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
      : runtime(runtime), refCount(refCount) {
    runtime->addNumRefCountedObjects();
  }

  virtual ~RefCounted() {
    assert(refCount.load() == 0 && "reference count must be zero");
    runtime->dropNumRefCountedObjects();
  }

  RefCounted(const RefCounted &) = delete;
  RefCounted &operator=(const RefCounted &) = delete;

  void addRef(int64_t count = 1) { refCount.fetch_add(count); }

  void dropRef(int64_t count = 1) {
    int64_t previous = refCount.fetch_sub(count);
    assert(previous >= count && "reference count should not go below zero");
    if (previous == count)
      destroy();
  }

protected:
  virtual void destroy() { delete this; }

private:
  AsyncRuntime *runtime;
  std::atomic<int64_t> refCount;
};
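
// For illustration only (not part of the runtime): a minimal sketch of the
// intended ownership protocol, assuming `obj` is any RefCounted subclass
// created by this runtime.
//
//   RefCounted *obj = ...;   // created with refCount == 1
//   obj->addRef();           // share with a second owner
//   obj->dropRef();          // second owner is done
//   obj->dropRef();          // last reference dropped, destroy() deletes obj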

} // namespace

// Returns the default per-process instance of an async runtime.
static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
  static auto runtime = std::make_unique<AsyncRuntime>();
  return runtime;
}

static void resetDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().reset();
}

static AsyncRuntime *getDefaultAsyncRuntime() {
  return getDefaultAsyncRuntimeInstance().get();
}

// Async token provides a mechanism to signal asynchronous operation completion.
struct AsyncToken : public RefCounted {
  // An AsyncToken is created with a reference count of 2 because it will be
  // returned to the `async.execute` caller and will also later be emplaced by
  // the asynchronously executed task. If the caller immediately drops its
  // reference, we must ensure that the token stays alive until the
  // asynchronous operation is completed.
  AsyncToken(AsyncRuntime *runtime)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}

  std::atomic<State::StateEnum> state;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
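
// For illustration only (hedged sketch, not part of the runtime): the typical
// token lifecycle as driven by lowered code or a host program linking this
// library. The producer/consumer split below is an assumption for the example.
//
//   AsyncToken *token = mlirAsyncRuntimeCreateToken();  // refCount == 2
//   // Producer (the async task), when its work is done:
//   //   mlirAsyncRuntimeEmplaceToken(token);           // drops the task's ref
//   // Consumer:
//   mlirAsyncRuntimeAwaitToken(token);       // blocks until available/error
//   mlirAsyncRuntimeDropRef(token, 1);       // drop the caller's reference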

// Async value provides a mechanism to access the result of asynchronous
// operations. It owns the storage that is used to store/load the value of the
// underlying type, and a flag to signal if the value is ready or not.
struct AsyncValue : public RefCounted {
  // Like an AsyncToken, an AsyncValue is created with a reference count of 2.
  AsyncValue(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
        storage(size) {}

  std::atomic<State::StateEnum> state;

  // Use vector of bytes to store async value payload.
  std::vector<std::byte> storage;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
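
// For illustration only (hedged sketch, not part of the runtime): producing
// and consuming an async value. The payload type `float` and the cast of the
// opaque ValueStorage pointer are assumptions for the example.
//
//   AsyncValue *value = mlirAsyncRuntimeCreateValue(sizeof(float));
//   // Producer: write the payload into the owned storage, then emplace.
//   auto *payload =
//       static_cast<float *>(mlirAsyncRuntimeGetValueStorage(value));
//   *payload = 42.0f;
//   mlirAsyncRuntimeEmplaceValue(value);     // drops the task's reference
//   // Consumer: wait for the value, read the payload, drop the reference.
//   mlirAsyncRuntimeAwaitValue(value);
//   float result =
//       *static_cast<float *>(mlirAsyncRuntimeGetValueStorage(value));
//   mlirAsyncRuntimeDropRef(value, 1);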

// Async group provides a mechanism to group together multiple async tokens or
// values to await on all of them together (wait for the completion of all
// tokens or values added to the group).
struct AsyncGroup : public RefCounted {
  AsyncGroup(AsyncRuntime *runtime, int64_t size)
      : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}

  std::atomic<int> pendingTokens;
  std::atomic<int> numErrors;
  std::atomic<int> rank;

  // Pending awaiters are guarded by a mutex.
  std::mutex mu;
  std::condition_variable cv;
  std::vector<std::function<void()>> awaiters;
};
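
// For illustration only (hedged sketch, not part of the runtime): awaiting
// several tokens through a group. The tokens `tokenA` and `tokenB` are
// assumptions for the example.
//
//   AsyncGroup *group = mlirAsyncRuntimeCreateGroup(/*size=*/2);
//   mlirAsyncRuntimeAddTokenToGroup(tokenA, group);
//   mlirAsyncRuntimeAddTokenToGroup(tokenB, group);
//   mlirAsyncRuntimeAwaitAllInGroup(group);   // blocks until both complete
//   bool anyFailed = mlirAsyncRuntimeIsGroupError(group);
//   mlirAsyncRuntimeDropRef(group, 1);        // drop the caller's reference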

// Adds references to a reference counted runtime object.
extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->addRef(count);
}

// Drops references from a reference counted runtime object.
extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
  RefCounted *refCounted = static_cast<RefCounted *>(ptr);
  refCounted->dropRef(count);
}

// Creates a new `async.token` in not-ready state.
extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
  AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
  return token;
}

// Creates a new `async.value` in not-ready state.
extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
  AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
  return value;
}

// Creates a new `async.group` in empty state.
extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
  AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
  return group;
}

extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
                                                   AsyncGroup *group) {
  std::unique_lock<std::mutex> lockToken(token->mu);
  std::unique_lock<std::mutex> lockGroup(group->mu);

  // Get the rank of the token inside the group before we drop the reference.
  int rank = group->rank.fetch_add(1);

  auto onTokenReady = [group, token]() {
    // Increment the number of errors in the group.
    if (State(token->state).isError())
      group->numErrors.fetch_add(1);

    // If pending tokens go below zero it means that more tokens than the group
    // size were added to this group.
    assert(group->pendingTokens > 0 && "wrong group size");

    // Run all group awaiters if it was the last token in the group.
    if (group->pendingTokens.fetch_sub(1) == 1) {
      group->cv.notify_all();
      for (auto &awaiter : group->awaiters)
        awaiter();
    }
  };

  if (State(token->state).isAvailableOrError()) {
    // Update group pending tokens immediately and maybe run awaiters.
    onTokenReady();

  } else {
    // Update group pending tokens when the token becomes ready. Because this
    // will happen asynchronously we must ensure that `group` is alive until
    // then, and re-acquire the lock.
    group->addRef();

    token->awaiters.emplace_back([group, onTokenReady]() {
      // Make sure that `dropRef` does not destroy the mutex owned by the lock.
      {
        std::unique_lock<std::mutex> lockGroup(group->mu);
        onTokenReady();
      }
      group->dropRef();
    });
  }

  return rank;
}

// Switches `async.token` to available or error state (terminal state) and runs
// all awaiters.
static void setTokenState(AsyncToken *token, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(token->state).isUnavailable() && "token must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(token->mu);
    token->state = state;
    token->cv.notify_all();
    for (auto &awaiter : token->awaiters)
      awaiter();
  }

  // Async tokens are created with a ref count of 2 to keep the token alive
  // until the async task completes. Drop this reference explicitly when the
  // token is emplaced.
  token->dropRef();
}

static void setValueState(AsyncValue *value, State state) {
  assert(state.isAvailableOrError() && "must be terminal state");
  assert(State(value->state).isUnavailable() && "value must be unavailable");

  // Make sure that `dropRef` does not destroy the mutex owned by the lock.
  {
    std::unique_lock<std::mutex> lock(value->mu);
    value->state = state;
    value->cv.notify_all();
    for (auto &awaiter : value->awaiters)
      awaiter();
  }

  // Async values are created with a ref count of 2 to keep the value alive
  // until the async task completes. Drop this reference explicitly when the
  // value is emplaced.
  value->dropRef();
}

extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
  setTokenState(token, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
  setValueState(value, State::kAvailable);
}

extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
  setTokenState(token, State::kError);
}

extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
  setValueState(value, State::kError);
}

extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
  return State(token->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
  return State(value->state).isError();
}

extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
  return group->numErrors.load() > 0;
}

extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
  std::unique_lock<std::mutex> lock(token->mu);
  if (!State(token->state).isAvailableOrError())
    token->cv.wait(
        lock, [token] { return State(token->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
  std::unique_lock<std::mutex> lock(value->mu);
  if (!State(value->state).isAvailableOrError())
    value->cv.wait(
        lock, [value] { return State(value->state).isAvailableOrError(); });
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens != 0)
    group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
}

// Returns a pointer to the storage owned by the async value.
extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
  assert(!State(value->state).isError() && "unexpected error state");
  return value->storage.data();
}

extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
  auto *runtime = getDefaultAsyncRuntime();
  runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
}

extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(token->mu);
  if (State(token->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    token->awaiters.emplace_back([execute]() { execute(); });
  }
}
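
// For illustration only (hedged sketch, not part of the runtime): how lowered
// coroutine code is expected to drive the execute entry points, assuming
// CoroHandle is an opaque handle and CoroResume a plain resume callback, as
// declared in AsyncRuntime.h. The `resumeCoroutine` function is hypothetical.
//
//   extern "C" void resumeCoroutine(void *handle);  // hypothetical
//   ...
//   // Resume the suspended coroutine on a worker thread:
//   mlirAsyncRuntimeExecute(handle, &resumeCoroutine);
//   // Or resume it only once `token` becomes available (or errors):
//   mlirAsyncRuntimeAwaitTokenAndExecute(token, handle, &resumeCoroutine);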

extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
                                                     CoroHandle handle,
                                                     CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(value->mu);
  if (State(value->state).isAvailableOrError()) {
    lock.unlock();
    execute();
  } else {
    value->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
                                                          CoroHandle handle,
                                                          CoroResume resume) {
  auto execute = [handle, resume]() { (*resume)(handle); };
  std::unique_lock<std::mutex> lock(group->mu);
  if (group->pendingTokens == 0) {
    lock.unlock();
    execute();
  } else {
    group->awaiters.emplace_back([execute]() { execute(); });
  }
}

extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
  return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
}

//===----------------------------------------------------------------------===//
// Small async runtime support library for testing.
//===----------------------------------------------------------------------===//

extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  static thread_local std::thread::id thisId = std::this_thread::get_id();
  std::cout << "Current thread id: " << thisId << '\n';
}

//===----------------------------------------------------------------------===//
// MLIR ExecutionEngine dynamic library integration.
//===----------------------------------------------------------------------===//

// Visual Studio had a bug that failed to compile nested generic lambdas
// inside an `extern "C"` function.
//   https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
// The bug is fixed in VS2019 16.1. Separating the declaration and definition is
// a workaround for older versions of Visual Studio.
// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
__mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);

// NOLINTNEXTLINE(*-identifier-naming): externally called.
void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
  auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
    assert(exportSymbols.count(name) == 0 && "symbol already exists");
    exportSymbols[name] = reinterpret_cast<void *>(ptr);
  };

  exportSymbol("mlirAsyncRuntimeAddRef",
               &mlir::runtime::mlirAsyncRuntimeAddRef);
  exportSymbol("mlirAsyncRuntimeDropRef",
               &mlir::runtime::mlirAsyncRuntimeDropRef);
  exportSymbol("mlirAsyncRuntimeExecute",
               &mlir::runtime::mlirAsyncRuntimeExecute);
  exportSymbol("mlirAsyncRuntimeGetValueStorage",
               &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
  exportSymbol("mlirAsyncRuntimeCreateToken",
               &mlir::runtime::mlirAsyncRuntimeCreateToken);
  exportSymbol("mlirAsyncRuntimeCreateValue",
               &mlir::runtime::mlirAsyncRuntimeCreateValue);
  exportSymbol("mlirAsyncRuntimeEmplaceToken",
               &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
  exportSymbol("mlirAsyncRuntimeEmplaceValue",
               &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
  exportSymbol("mlirAsyncRuntimeSetTokenError",
               &mlir::runtime::mlirAsyncRuntimeSetTokenError);
  exportSymbol("mlirAsyncRuntimeSetValueError",
               &mlir::runtime::mlirAsyncRuntimeSetValueError);
  exportSymbol("mlirAsyncRuntimeIsTokenError",
               &mlir::runtime::mlirAsyncRuntimeIsTokenError);
  exportSymbol("mlirAsyncRuntimeIsValueError",
               &mlir::runtime::mlirAsyncRuntimeIsValueError);
  exportSymbol("mlirAsyncRuntimeIsGroupError",
               &mlir::runtime::mlirAsyncRuntimeIsGroupError);
  exportSymbol("mlirAsyncRuntimeAwaitToken",
               &mlir::runtime::mlirAsyncRuntimeAwaitToken);
  exportSymbol("mlirAsyncRuntimeAwaitValue",
               &mlir::runtime::mlirAsyncRuntimeAwaitValue);
  exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
  exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
  exportSymbol("mlirAsyncRuntimeCreateGroup",
               &mlir::runtime::mlirAsyncRuntimeCreateGroup);
  exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
               &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
  exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
               &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
  exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
               &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
  exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
               &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
}

// NOLINTNEXTLINE(*-identifier-naming): externally called.
extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
  resetDefaultAsyncRuntime();
}

} // namespace runtime
} // namespace mlir