xref: /llvm-project/mlir/lib/ExecutionEngine/AsyncRuntime.cpp (revision 716042a63f26cd020eb72960f72fa97b9a197382)
136ce915aSLei Zhang //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
236ce915aSLei Zhang //
336ce915aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
436ce915aSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
536ce915aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
636ce915aSLei Zhang //
736ce915aSLei Zhang //===----------------------------------------------------------------------===//
836ce915aSLei Zhang //
936ce915aSLei Zhang // This file implements basic Async runtime API for supporting Async dialect
1036ce915aSLei Zhang // to LLVM dialect lowering.
1136ce915aSLei Zhang //
1236ce915aSLei Zhang //===----------------------------------------------------------------------===//
1336ce915aSLei Zhang 
1436ce915aSLei Zhang #include "mlir/ExecutionEngine/AsyncRuntime.h"
1536ce915aSLei Zhang 
16c30ab6c2SEugene Zhulenev #include <atomic>
17a86a9b5eSEugene Zhulenev #include <cassert>
1836ce915aSLei Zhang #include <condition_variable>
1936ce915aSLei Zhang #include <functional>
2036ce915aSLei Zhang #include <iostream>
2136ce915aSLei Zhang #include <mutex>
2236ce915aSLei Zhang #include <thread>
2336ce915aSLei Zhang #include <vector>
2436ce915aSLei Zhang 
251fc98642SEugene Zhulenev #include "llvm/ADT/StringMap.h"
26bb0e6213SEugene Zhulenev #include "llvm/Support/ThreadPool.h"
271fc98642SEugene Zhulenev 
2811f1027bSEugene Zhulenev using namespace mlir::runtime;
2911f1027bSEugene Zhulenev 
3036ce915aSLei Zhang //===----------------------------------------------------------------------===//
3136ce915aSLei Zhang // Async runtime API.
3236ce915aSLei Zhang //===----------------------------------------------------------------------===//
3336ce915aSLei Zhang 
3411f1027bSEugene Zhulenev namespace mlir {
3511f1027bSEugene Zhulenev namespace runtime {
36a86a9b5eSEugene Zhulenev namespace {
37a86a9b5eSEugene Zhulenev 
38a86a9b5eSEugene Zhulenev // Forward declare class defined below.
39a86a9b5eSEugene Zhulenev class RefCounted;
40a86a9b5eSEugene Zhulenev 
41a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
42a86a9b5eSEugene Zhulenev // AsyncRuntime orchestrates all async operations and Async runtime API is built
43a86a9b5eSEugene Zhulenev // on top of the default runtime instance.
44a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
45a86a9b5eSEugene Zhulenev 
46a86a9b5eSEugene Zhulenev class AsyncRuntime {
47a86a9b5eSEugene Zhulenev public:
AsyncRuntime()48a86a9b5eSEugene Zhulenev   AsyncRuntime() : numRefCountedObjects(0) {}
49a86a9b5eSEugene Zhulenev 
~AsyncRuntime()50a86a9b5eSEugene Zhulenev   ~AsyncRuntime() {
51bb0e6213SEugene Zhulenev     threadPool.wait(); // wait for the completion of all async tasks
52a86a9b5eSEugene Zhulenev     assert(getNumRefCountedObjects() == 0 &&
53a86a9b5eSEugene Zhulenev            "all ref counted objects must be destroyed");
54a86a9b5eSEugene Zhulenev   }
55a86a9b5eSEugene Zhulenev 
getNumRefCountedObjects()5692db09cdSEugene Zhulenev   int64_t getNumRefCountedObjects() {
57a86a9b5eSEugene Zhulenev     return numRefCountedObjects.load(std::memory_order_relaxed);
58a86a9b5eSEugene Zhulenev   }
59a86a9b5eSEugene Zhulenev 
getThreadPool()604a4fb930SMehdi Amini   llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }
61bb0e6213SEugene Zhulenev 
62a86a9b5eSEugene Zhulenev private:
63a86a9b5eSEugene Zhulenev   friend class RefCounted;
64a86a9b5eSEugene Zhulenev 
65a86a9b5eSEugene Zhulenev   // Count the total number of reference counted objects in this instance
66a86a9b5eSEugene Zhulenev   // of an AsyncRuntime. For debugging purposes only.
addNumRefCountedObjects()67a86a9b5eSEugene Zhulenev   void addNumRefCountedObjects() {
68a86a9b5eSEugene Zhulenev     numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
69a86a9b5eSEugene Zhulenev   }
dropNumRefCountedObjects()70a86a9b5eSEugene Zhulenev   void dropNumRefCountedObjects() {
71a86a9b5eSEugene Zhulenev     numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
72a86a9b5eSEugene Zhulenev   }
73a86a9b5eSEugene Zhulenev 
7492db09cdSEugene Zhulenev   std::atomic<int64_t> numRefCountedObjects;
75*716042a6SMehdi Amini   llvm::DefaultThreadPool threadPool;
76a86a9b5eSEugene Zhulenev };
77a86a9b5eSEugene Zhulenev 
// -------------------------------------------------------------------------- //
// A state of the async runtime value (token, value or group).
// -------------------------------------------------------------------------- //

// Thin wrapper around `StateEnum` that adds convenience predicates and a
// printable name. It converts implicitly from and to the raw enum so it can be
// used together with `std::atomic<State::StateEnum>` storage.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  // Const-qualified so that a `const State` can also be converted back to the
  // raw enum.
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  // True for both terminal states.
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name of the state.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // `StateEnum` has a fixed underlying type, so values other than the
    // enumerators above are representable. Never flow off the end of a
    // function that returns a value.
    return "unknown";
  }

private:
  StateEnum state;
};
11739957aa4SEugene Zhulenev 
11839957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
119a86a9b5eSEugene Zhulenev // A base class for all reference counted objects created by the async runtime.
120a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
121a86a9b5eSEugene Zhulenev 
122a86a9b5eSEugene Zhulenev class RefCounted {
123a86a9b5eSEugene Zhulenev public:
RefCounted(AsyncRuntime * runtime,int64_t refCount=1)12492db09cdSEugene Zhulenev   RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125a86a9b5eSEugene Zhulenev       : runtime(runtime), refCount(refCount) {
126a86a9b5eSEugene Zhulenev     runtime->addNumRefCountedObjects();
127a86a9b5eSEugene Zhulenev   }
128a86a9b5eSEugene Zhulenev 
~RefCounted()129a86a9b5eSEugene Zhulenev   virtual ~RefCounted() {
130a86a9b5eSEugene Zhulenev     assert(refCount.load() == 0 && "reference count must be zero");
131a86a9b5eSEugene Zhulenev     runtime->dropNumRefCountedObjects();
132a86a9b5eSEugene Zhulenev   }
133a86a9b5eSEugene Zhulenev 
134a86a9b5eSEugene Zhulenev   RefCounted(const RefCounted &) = delete;
135a86a9b5eSEugene Zhulenev   RefCounted &operator=(const RefCounted &) = delete;
136a86a9b5eSEugene Zhulenev 
addRef(int64_t count=1)13792db09cdSEugene Zhulenev   void addRef(int64_t count = 1) { refCount.fetch_add(count); }
138a86a9b5eSEugene Zhulenev 
dropRef(int64_t count=1)13992db09cdSEugene Zhulenev   void dropRef(int64_t count = 1) {
14092db09cdSEugene Zhulenev     int64_t previous = refCount.fetch_sub(count);
141a86a9b5eSEugene Zhulenev     assert(previous >= count && "reference count should not go below zero");
142a86a9b5eSEugene Zhulenev     if (previous == count)
143a86a9b5eSEugene Zhulenev       destroy();
144a86a9b5eSEugene Zhulenev   }
145a86a9b5eSEugene Zhulenev 
146a86a9b5eSEugene Zhulenev protected:
destroy()147a86a9b5eSEugene Zhulenev   virtual void destroy() { delete this; }
148a86a9b5eSEugene Zhulenev 
149a86a9b5eSEugene Zhulenev private:
150a86a9b5eSEugene Zhulenev   AsyncRuntime *runtime;
15192db09cdSEugene Zhulenev   std::atomic<int64_t> refCount;
152a86a9b5eSEugene Zhulenev };
153a86a9b5eSEugene Zhulenev 
154a86a9b5eSEugene Zhulenev } // namespace
155a86a9b5eSEugene Zhulenev 
15611f1027bSEugene Zhulenev // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()1571fc98642SEugene Zhulenev static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
15811f1027bSEugene Zhulenev   static auto runtime = std::make_unique<AsyncRuntime>();
1591fc98642SEugene Zhulenev   return runtime;
1601fc98642SEugene Zhulenev }
1611fc98642SEugene Zhulenev 
resetDefaultAsyncRuntime()1621fc98642SEugene Zhulenev static void resetDefaultAsyncRuntime() {
1631fc98642SEugene Zhulenev   return getDefaultAsyncRuntimeInstance().reset();
1641fc98642SEugene Zhulenev }
1651fc98642SEugene Zhulenev 
getDefaultAsyncRuntime()1661fc98642SEugene Zhulenev static AsyncRuntime *getDefaultAsyncRuntime() {
1671fc98642SEugene Zhulenev   return getDefaultAsyncRuntimeInstance().get();
16811f1027bSEugene Zhulenev }
16911f1027bSEugene Zhulenev 
170621ad468SEugene Zhulenev // Async token provides a mechanism to signal asynchronous operation completion.
171a86a9b5eSEugene Zhulenev struct AsyncToken : public RefCounted {
172a86a9b5eSEugene Zhulenev   // AsyncToken created with a reference count of 2 because it will be returned
173a86a9b5eSEugene Zhulenev   // to the `async.execute` caller and also will be later on emplaced by the
174a86a9b5eSEugene Zhulenev   // asynchronously executed task. If the caller immediately will drop its
175a86a9b5eSEugene Zhulenev   // reference we must ensure that the token will be alive until the
176a86a9b5eSEugene Zhulenev   // asynchronous operation is completed.
AsyncTokenmlir::runtime::AsyncToken177a2223b09SEugene Zhulenev   AsyncToken(AsyncRuntime *runtime)
17839957aa4SEugene Zhulenev       : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
179a86a9b5eSEugene Zhulenev 
18039957aa4SEugene Zhulenev   std::atomic<State::StateEnum> state;
181a2223b09SEugene Zhulenev 
182a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
18336ce915aSLei Zhang   std::mutex mu;
18436ce915aSLei Zhang   std::condition_variable cv;
18536ce915aSLei Zhang   std::vector<std::function<void()>> awaiters;
18636ce915aSLei Zhang };
18736ce915aSLei Zhang 
188621ad468SEugene Zhulenev // Async value provides a mechanism to access the result of asynchronous
189621ad468SEugene Zhulenev // operations. It owns the storage that is used to store/load the value of the
190621ad468SEugene Zhulenev // underlying type, and a flag to signal if the value is ready or not.
191621ad468SEugene Zhulenev struct AsyncValue : public RefCounted {
192621ad468SEugene Zhulenev   // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValuemlir::runtime::AsyncValue19392db09cdSEugene Zhulenev   AsyncValue(AsyncRuntime *runtime, int64_t size)
19439957aa4SEugene Zhulenev       : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
19539957aa4SEugene Zhulenev         storage(size) {}
196621ad468SEugene Zhulenev 
19739957aa4SEugene Zhulenev   std::atomic<State::StateEnum> state;
198621ad468SEugene Zhulenev 
199621ad468SEugene Zhulenev   // Use vector of bytes to store async value payload.
2004109276fSyijiagu   std::vector<std::byte> storage;
201a2223b09SEugene Zhulenev 
202a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
203a2223b09SEugene Zhulenev   std::mutex mu;
204a2223b09SEugene Zhulenev   std::condition_variable cv;
205a2223b09SEugene Zhulenev   std::vector<std::function<void()>> awaiters;
206621ad468SEugene Zhulenev };
207621ad468SEugene Zhulenev 
208621ad468SEugene Zhulenev // Async group provides a mechanism to group together multiple async tokens or
209621ad468SEugene Zhulenev // values to await on all of them together (wait for the completion of all
210621ad468SEugene Zhulenev // tokens or values added to the group).
211a86a9b5eSEugene Zhulenev struct AsyncGroup : public RefCounted {
AsyncGroupmlir::runtime::AsyncGroup212d43b2360SEugene Zhulenev   AsyncGroup(AsyncRuntime *runtime, int64_t size)
213d43b2360SEugene Zhulenev       : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
214a86a9b5eSEugene Zhulenev 
215a86a9b5eSEugene Zhulenev   std::atomic<int> pendingTokens;
216d8c84d2aSEugene Zhulenev   std::atomic<int> numErrors;
217a86a9b5eSEugene Zhulenev   std::atomic<int> rank;
218a86a9b5eSEugene Zhulenev 
219a2223b09SEugene Zhulenev   // Pending awaiters are guarded by a mutex.
220c30ab6c2SEugene Zhulenev   std::mutex mu;
221c30ab6c2SEugene Zhulenev   std::condition_variable cv;
222c30ab6c2SEugene Zhulenev   std::vector<std::function<void()>> awaiters;
223c30ab6c2SEugene Zhulenev };
224c30ab6c2SEugene Zhulenev 
225a86a9b5eSEugene Zhulenev // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)22692db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
227a86a9b5eSEugene Zhulenev   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
228a86a9b5eSEugene Zhulenev   refCounted->addRef(count);
229a86a9b5eSEugene Zhulenev }
230a86a9b5eSEugene Zhulenev 
231a86a9b5eSEugene Zhulenev // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)23292db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
233a86a9b5eSEugene Zhulenev   RefCounted *refCounted = static_cast<RefCounted *>(ptr);
234a86a9b5eSEugene Zhulenev   refCounted->dropRef(count);
235a86a9b5eSEugene Zhulenev }
236a86a9b5eSEugene Zhulenev 
237621ad468SEugene Zhulenev // Creates a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()2386fd9e59eSPaul Lietar extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
2391fc98642SEugene Zhulenev   AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
24036ce915aSLei Zhang   return token;
24136ce915aSLei Zhang }
24236ce915aSLei Zhang 
243621ad468SEugene Zhulenev // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)24492db09cdSEugene Zhulenev extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
2451fc98642SEugene Zhulenev   AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
246621ad468SEugene Zhulenev   return value;
247621ad468SEugene Zhulenev }
248621ad468SEugene Zhulenev 
249c30ab6c2SEugene Zhulenev // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)250d43b2360SEugene Zhulenev extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
251d43b2360SEugene Zhulenev   AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
252c30ab6c2SEugene Zhulenev   return group;
253c30ab6c2SEugene Zhulenev }
254c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)2553d95d1b4SEugene Zhulenev extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
2563d95d1b4SEugene Zhulenev                                                    AsyncGroup *group) {
257c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lockToken(token->mu);
258c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lockGroup(group->mu);
259c30ab6c2SEugene Zhulenev 
260a86a9b5eSEugene Zhulenev   // Get the rank of the token inside the group before we drop the reference.
261a86a9b5eSEugene Zhulenev   int rank = group->rank.fetch_add(1);
262c30ab6c2SEugene Zhulenev 
263d8c84d2aSEugene Zhulenev   auto onTokenReady = [group, token]() {
264d8c84d2aSEugene Zhulenev     // Increment the number of errors in the group.
265d8c84d2aSEugene Zhulenev     if (State(token->state).isError())
266d8c84d2aSEugene Zhulenev       group->numErrors.fetch_add(1);
267d8c84d2aSEugene Zhulenev 
268d43b2360SEugene Zhulenev     // If pending tokens go below zero it means that more tokens than the group
269d43b2360SEugene Zhulenev     // size were added to this group.
270d43b2360SEugene Zhulenev     assert(group->pendingTokens > 0 && "wrong group size");
271d43b2360SEugene Zhulenev 
272c30ab6c2SEugene Zhulenev     // Run all group awaiters if it was the last token in the group.
273c30ab6c2SEugene Zhulenev     if (group->pendingTokens.fetch_sub(1) == 1) {
274c30ab6c2SEugene Zhulenev       group->cv.notify_all();
275c30ab6c2SEugene Zhulenev       for (auto &awaiter : group->awaiters)
276c30ab6c2SEugene Zhulenev         awaiter();
277c30ab6c2SEugene Zhulenev     }
278c30ab6c2SEugene Zhulenev   };
279c30ab6c2SEugene Zhulenev 
28039957aa4SEugene Zhulenev   if (State(token->state).isAvailableOrError()) {
2813d95d1b4SEugene Zhulenev     // Update group pending tokens immediately and maybe run awaiters.
2823d95d1b4SEugene Zhulenev     onTokenReady();
2833d95d1b4SEugene Zhulenev 
284a86a9b5eSEugene Zhulenev   } else {
2853d95d1b4SEugene Zhulenev     // Update group pending tokens when token will become ready. Because this
2863d95d1b4SEugene Zhulenev     // will happen asynchronously we must ensure that `group` is alive until
2873d95d1b4SEugene Zhulenev     // then, and re-ackquire the lock.
288a86a9b5eSEugene Zhulenev     group->addRef();
2893d95d1b4SEugene Zhulenev 
290e5639b3fSMehdi Amini     token->awaiters.emplace_back([group, onTokenReady]() {
2913d95d1b4SEugene Zhulenev       // Make sure that `dropRef` does not destroy the mutex owned by the lock.
2923d95d1b4SEugene Zhulenev       {
2933d95d1b4SEugene Zhulenev         std::unique_lock<std::mutex> lockGroup(group->mu);
2943d95d1b4SEugene Zhulenev         onTokenReady();
2953d95d1b4SEugene Zhulenev       }
2963d95d1b4SEugene Zhulenev       group->dropRef();
2973d95d1b4SEugene Zhulenev     });
298a86a9b5eSEugene Zhulenev   }
299c30ab6c2SEugene Zhulenev 
300a86a9b5eSEugene Zhulenev   return rank;
301c30ab6c2SEugene Zhulenev }
302c30ab6c2SEugene Zhulenev 
30339957aa4SEugene Zhulenev // Switches `async.token` to available or error state (terminatl state) and runs
30439957aa4SEugene Zhulenev // all awaiters.
setTokenState(AsyncToken * token,State state)30539957aa4SEugene Zhulenev static void setTokenState(AsyncToken *token, State state) {
30639957aa4SEugene Zhulenev   assert(state.isAvailableOrError() && "must be terminal state");
30739957aa4SEugene Zhulenev   assert(State(token->state).isUnavailable() && "token must be unavailable");
30839957aa4SEugene Zhulenev 
3093d95d1b4SEugene Zhulenev   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
3103d95d1b4SEugene Zhulenev   {
31136ce915aSLei Zhang     std::unique_lock<std::mutex> lock(token->mu);
31239957aa4SEugene Zhulenev     token->state = state;
31336ce915aSLei Zhang     token->cv.notify_all();
31436ce915aSLei Zhang     for (auto &awaiter : token->awaiters)
31536ce915aSLei Zhang       awaiter();
3163d95d1b4SEugene Zhulenev   }
317a86a9b5eSEugene Zhulenev 
318a86a9b5eSEugene Zhulenev   // Async tokens created with a ref count `2` to keep token alive until the
319a86a9b5eSEugene Zhulenev   // async task completes. Drop this reference explicitly when token emplaced.
320a86a9b5eSEugene Zhulenev   token->dropRef();
32136ce915aSLei Zhang }
32236ce915aSLei Zhang 
setValueState(AsyncValue * value,State state)32339957aa4SEugene Zhulenev static void setValueState(AsyncValue *value, State state) {
32439957aa4SEugene Zhulenev   assert(state.isAvailableOrError() && "must be terminal state");
32539957aa4SEugene Zhulenev   assert(State(value->state).isUnavailable() && "value must be unavailable");
32639957aa4SEugene Zhulenev 
327621ad468SEugene Zhulenev   // Make sure that `dropRef` does not destroy the mutex owned by the lock.
328621ad468SEugene Zhulenev   {
329621ad468SEugene Zhulenev     std::unique_lock<std::mutex> lock(value->mu);
33039957aa4SEugene Zhulenev     value->state = state;
331621ad468SEugene Zhulenev     value->cv.notify_all();
332621ad468SEugene Zhulenev     for (auto &awaiter : value->awaiters)
333621ad468SEugene Zhulenev       awaiter();
334621ad468SEugene Zhulenev   }
335621ad468SEugene Zhulenev 
336621ad468SEugene Zhulenev   // Async values created with a ref count `2` to keep value alive until the
337621ad468SEugene Zhulenev   // async task completes. Drop this reference explicitly when value emplaced.
338621ad468SEugene Zhulenev   value->dropRef();
339621ad468SEugene Zhulenev }
340621ad468SEugene Zhulenev 
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)34139957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
34239957aa4SEugene Zhulenev   setTokenState(token, State::kAvailable);
34339957aa4SEugene Zhulenev }
34439957aa4SEugene Zhulenev 
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)34539957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
34639957aa4SEugene Zhulenev   setValueState(value, State::kAvailable);
34739957aa4SEugene Zhulenev }
34839957aa4SEugene Zhulenev 
mlirAsyncRuntimeSetTokenError(AsyncToken * token)34939957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
35039957aa4SEugene Zhulenev   setTokenState(token, State::kError);
35139957aa4SEugene Zhulenev }
35239957aa4SEugene Zhulenev 
mlirAsyncRuntimeSetValueError(AsyncValue * value)35339957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
35439957aa4SEugene Zhulenev   setValueState(value, State::kError);
35539957aa4SEugene Zhulenev }
35639957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsTokenError(AsyncToken * token)35739957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
35839957aa4SEugene Zhulenev   return State(token->state).isError();
35939957aa4SEugene Zhulenev }
36039957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsValueError(AsyncValue * value)36139957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
36239957aa4SEugene Zhulenev   return State(value->state).isError();
36339957aa4SEugene Zhulenev }
36439957aa4SEugene Zhulenev 
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)365d8c84d2aSEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
366d8c84d2aSEugene Zhulenev   return group->numErrors.load() > 0;
367d8c84d2aSEugene Zhulenev }
368d8c84d2aSEugene Zhulenev 
mlirAsyncRuntimeAwaitToken(AsyncToken * token)3696fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
37036ce915aSLei Zhang   std::unique_lock<std::mutex> lock(token->mu);
37139957aa4SEugene Zhulenev   if (!State(token->state).isAvailableOrError())
37239957aa4SEugene Zhulenev     token->cv.wait(
37339957aa4SEugene Zhulenev         lock, [token] { return State(token->state).isAvailableOrError(); });
374c30ab6c2SEugene Zhulenev }
375c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimeAwaitValue(AsyncValue * value)376621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
377621ad468SEugene Zhulenev   std::unique_lock<std::mutex> lock(value->mu);
37839957aa4SEugene Zhulenev   if (!State(value->state).isAvailableOrError())
37939957aa4SEugene Zhulenev     value->cv.wait(
38039957aa4SEugene Zhulenev         lock, [value] { return State(value->state).isAvailableOrError(); });
381621ad468SEugene Zhulenev }
382621ad468SEugene Zhulenev 
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)3833d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
384c30ab6c2SEugene Zhulenev   std::unique_lock<std::mutex> lock(group->mu);
385c30ab6c2SEugene Zhulenev   if (group->pendingTokens != 0)
386c30ab6c2SEugene Zhulenev     group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
38736ce915aSLei Zhang }
38836ce915aSLei Zhang 
389621ad468SEugene Zhulenev // Returns a pointer to the storage owned by the async value.
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)390621ad468SEugene Zhulenev extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
39139957aa4SEugene Zhulenev   assert(!State(value->state).isError() && "unexpected error state");
392621ad468SEugene Zhulenev   return value->storage.data();
393621ad468SEugene Zhulenev }
394621ad468SEugene Zhulenev 
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)3956fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
396bb0e6213SEugene Zhulenev   auto *runtime = getDefaultAsyncRuntime();
397bb0e6213SEugene Zhulenev   runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
39836ce915aSLei Zhang }
39936ce915aSLei Zhang 
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)4006fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
4016fd9e59eSPaul Lietar                                                      CoroHandle handle,
40236ce915aSLei Zhang                                                      CoroResume resume) {
4033d95d1b4SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
404f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(token->mu);
40539957aa4SEugene Zhulenev   if (State(token->state).isAvailableOrError()) {
406f63f28edSEugene Zhulenev     lock.unlock();
4073d95d1b4SEugene Zhulenev     execute();
408a2223b09SEugene Zhulenev   } else {
409e5639b3fSMehdi Amini     token->awaiters.emplace_back([execute]() { execute(); });
41036ce915aSLei Zhang   }
411a2223b09SEugene Zhulenev }
41236ce915aSLei Zhang 
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)413621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
414621ad468SEugene Zhulenev                                                      CoroHandle handle,
415621ad468SEugene Zhulenev                                                      CoroResume resume) {
416621ad468SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
417f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(value->mu);
41839957aa4SEugene Zhulenev   if (State(value->state).isAvailableOrError()) {
419f63f28edSEugene Zhulenev     lock.unlock();
420621ad468SEugene Zhulenev     execute();
421a2223b09SEugene Zhulenev   } else {
422e5639b3fSMehdi Amini     value->awaiters.emplace_back([execute]() { execute(); });
423621ad468SEugene Zhulenev   }
424a2223b09SEugene Zhulenev }
425621ad468SEugene Zhulenev 
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)4263d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
4273d95d1b4SEugene Zhulenev                                                           CoroHandle handle,
428c30ab6c2SEugene Zhulenev                                                           CoroResume resume) {
4293d95d1b4SEugene Zhulenev   auto execute = [handle, resume]() { (*resume)(handle); };
430f63f28edSEugene Zhulenev   std::unique_lock<std::mutex> lock(group->mu);
431a2223b09SEugene Zhulenev   if (group->pendingTokens == 0) {
432f63f28edSEugene Zhulenev     lock.unlock();
4333d95d1b4SEugene Zhulenev     execute();
434a2223b09SEugene Zhulenev   } else {
435e5639b3fSMehdi Amini     group->awaiters.emplace_back([execute]() { execute(); });
436c30ab6c2SEugene Zhulenev   }
437a2223b09SEugene Zhulenev }
438c30ab6c2SEugene Zhulenev 
mlirAsyncRuntimGetNumWorkerThreads()439149311b4Sbakhtiyar extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
440744616b3SMehdi Amini   return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
441149311b4Sbakhtiyar }
442149311b4Sbakhtiyar 
44336ce915aSLei Zhang //===----------------------------------------------------------------------===//
44436ce915aSLei Zhang // Small async runtime support library for testing.
44536ce915aSLei Zhang //===----------------------------------------------------------------------===//
44636ce915aSLei Zhang 
// Prints the id of the calling thread to the standard output, followed by a
// newline.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  const std::thread::id currentId = std::this_thread::get_id();
  std::cout << "Current thread id: " << currentId << '\n';
}
45136ce915aSLei Zhang 
4521fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
453530db6a3SIngo Müller // MLIR ExecutionEngine dynamic library integration.
4541fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
4551fc98642SEugene Zhulenev 
456dd2dac2fSMatthew Parkinson // Visual Studio had a bug that fails to compile nested generic lambdas
457dd2dac2fSMatthew Parkinson // inside an `extern "C"` function.
458dd2dac2fSMatthew Parkinson //   https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
459dd2dac2fSMatthew Parkinson // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
460dd2dac2fSMatthew Parkinson // a work around for older versions of Visual Studio.
46102b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
462f9bce19eSIngo Müller extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
4630b3841ebSIngo Müller __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
464dd2dac2fSMatthew Parkinson 
46502b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_init(llvm::StringMap<void * > & exportSymbols)4660b3841ebSIngo Müller void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
4671fc98642SEugene Zhulenev   auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
4681fc98642SEugene Zhulenev     assert(exportSymbols.count(name) == 0 && "symbol already exists");
4691fc98642SEugene Zhulenev     exportSymbols[name] = reinterpret_cast<void *>(ptr);
4701fc98642SEugene Zhulenev   };
4711fc98642SEugene Zhulenev 
4721fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAddRef",
4731fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAddRef);
4741fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeDropRef",
4751fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeDropRef);
4761fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeExecute",
4771fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeExecute);
4781fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeGetValueStorage",
4791fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
4801fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateToken",
4811fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateToken);
4821fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateValue",
4831fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateValue);
4841fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeEmplaceToken",
4851fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
4861fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeEmplaceValue",
4871fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
48839957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeSetTokenError",
48939957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeSetTokenError);
49039957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeSetValueError",
49139957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeSetValueError);
49239957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsTokenError",
49339957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsTokenError);
49439957aa4SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsValueError",
49539957aa4SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsValueError);
496d8c84d2aSEugene Zhulenev   exportSymbol("mlirAsyncRuntimeIsGroupError",
497d8c84d2aSEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeIsGroupError);
4981fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitToken",
4991fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitToken);
5001fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitValue",
5011fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitValue);
5021fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
5031fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
5041fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
5051fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
5061fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeCreateGroup",
5071fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeCreateGroup);
5081fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
5091fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
5101fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
5111fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
5121fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
5131fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
514149311b4Sbakhtiyar   exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
515149311b4Sbakhtiyar                &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
5161fc98642SEugene Zhulenev   exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
5171fc98642SEugene Zhulenev                &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
5181fc98642SEugene Zhulenev }
5191fc98642SEugene Zhulenev 
52002b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_destroy()521f9bce19eSIngo Müller extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
5220b3841ebSIngo Müller   resetDefaultAsyncRuntime();
5230b3841ebSIngo Müller }
5241fc98642SEugene Zhulenev 
525dd2dac2fSMatthew Parkinson } // namespace runtime
526dd2dac2fSMatthew Parkinson } // namespace mlir
527