136ce915aSLei Zhang //===- AsyncRuntime.cpp - Async runtime reference implementation ----------===//
236ce915aSLei Zhang //
336ce915aSLei Zhang // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
436ce915aSLei Zhang // See https://llvm.org/LICENSE.txt for license information.
536ce915aSLei Zhang // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
636ce915aSLei Zhang //
736ce915aSLei Zhang //===----------------------------------------------------------------------===//
836ce915aSLei Zhang //
936ce915aSLei Zhang // This file implements basic Async runtime API for supporting Async dialect
1036ce915aSLei Zhang // to LLVM dialect lowering.
1136ce915aSLei Zhang //
1236ce915aSLei Zhang //===----------------------------------------------------------------------===//
1336ce915aSLei Zhang
1436ce915aSLei Zhang #include "mlir/ExecutionEngine/AsyncRuntime.h"
1536ce915aSLei Zhang
16c30ab6c2SEugene Zhulenev #include <atomic>
17a86a9b5eSEugene Zhulenev #include <cassert>
1836ce915aSLei Zhang #include <condition_variable>
1936ce915aSLei Zhang #include <functional>
2036ce915aSLei Zhang #include <iostream>
2136ce915aSLei Zhang #include <mutex>
2236ce915aSLei Zhang #include <thread>
2336ce915aSLei Zhang #include <vector>
2436ce915aSLei Zhang
251fc98642SEugene Zhulenev #include "llvm/ADT/StringMap.h"
26bb0e6213SEugene Zhulenev #include "llvm/Support/ThreadPool.h"
271fc98642SEugene Zhulenev
2811f1027bSEugene Zhulenev using namespace mlir::runtime;
2911f1027bSEugene Zhulenev
3036ce915aSLei Zhang //===----------------------------------------------------------------------===//
3136ce915aSLei Zhang // Async runtime API.
3236ce915aSLei Zhang //===----------------------------------------------------------------------===//
3336ce915aSLei Zhang
3411f1027bSEugene Zhulenev namespace mlir {
3511f1027bSEugene Zhulenev namespace runtime {
36a86a9b5eSEugene Zhulenev namespace {
37a86a9b5eSEugene Zhulenev
38a86a9b5eSEugene Zhulenev // Forward declare class defined below.
39a86a9b5eSEugene Zhulenev class RefCounted;
40a86a9b5eSEugene Zhulenev
41a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
42a86a9b5eSEugene Zhulenev // AsyncRuntime orchestrates all async operations and Async runtime API is built
43a86a9b5eSEugene Zhulenev // on top of the default runtime instance.
44a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
45a86a9b5eSEugene Zhulenev
46a86a9b5eSEugene Zhulenev class AsyncRuntime {
47a86a9b5eSEugene Zhulenev public:
AsyncRuntime()48a86a9b5eSEugene Zhulenev AsyncRuntime() : numRefCountedObjects(0) {}
49a86a9b5eSEugene Zhulenev
~AsyncRuntime()50a86a9b5eSEugene Zhulenev ~AsyncRuntime() {
51bb0e6213SEugene Zhulenev threadPool.wait(); // wait for the completion of all async tasks
52a86a9b5eSEugene Zhulenev assert(getNumRefCountedObjects() == 0 &&
53a86a9b5eSEugene Zhulenev "all ref counted objects must be destroyed");
54a86a9b5eSEugene Zhulenev }
55a86a9b5eSEugene Zhulenev
getNumRefCountedObjects()5692db09cdSEugene Zhulenev int64_t getNumRefCountedObjects() {
57a86a9b5eSEugene Zhulenev return numRefCountedObjects.load(std::memory_order_relaxed);
58a86a9b5eSEugene Zhulenev }
59a86a9b5eSEugene Zhulenev
getThreadPool()604a4fb930SMehdi Amini llvm::ThreadPoolInterface &getThreadPool() { return threadPool; }
61bb0e6213SEugene Zhulenev
62a86a9b5eSEugene Zhulenev private:
63a86a9b5eSEugene Zhulenev friend class RefCounted;
64a86a9b5eSEugene Zhulenev
65a86a9b5eSEugene Zhulenev // Count the total number of reference counted objects in this instance
66a86a9b5eSEugene Zhulenev // of an AsyncRuntime. For debugging purposes only.
addNumRefCountedObjects()67a86a9b5eSEugene Zhulenev void addNumRefCountedObjects() {
68a86a9b5eSEugene Zhulenev numRefCountedObjects.fetch_add(1, std::memory_order_relaxed);
69a86a9b5eSEugene Zhulenev }
dropNumRefCountedObjects()70a86a9b5eSEugene Zhulenev void dropNumRefCountedObjects() {
71a86a9b5eSEugene Zhulenev numRefCountedObjects.fetch_sub(1, std::memory_order_relaxed);
72a86a9b5eSEugene Zhulenev }
73a86a9b5eSEugene Zhulenev
7492db09cdSEugene Zhulenev std::atomic<int64_t> numRefCountedObjects;
75*716042a6SMehdi Amini llvm::DefaultThreadPool threadPool;
76a86a9b5eSEugene Zhulenev };
77a86a9b5eSEugene Zhulenev
78a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
7939957aa4SEugene Zhulenev // A state of the async runtime value (token, value or group).
8039957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
8139957aa4SEugene Zhulenev
// Small value wrapper around the lifecycle state of an async runtime object.
// It converts implicitly to and from `StateEnum`, so it can be built directly
// from a value loaded out of a `std::atomic<State::StateEnum>`.
class State {
public:
  enum StateEnum : int8_t {
    // The underlying value is not yet available for consumption.
    kUnavailable = 0,
    // The underlying value is available for consumption. This state can not
    // transition to any other state.
    kAvailable = 1,
    // This underlying value is available and contains an error. This state can
    // not transition to any other state.
    kError = 2,
  };

  /* implicit */ State(StateEnum s) : state(s) {}
  // The conversion does not modify the object, so it is usable on const
  // instances as well.
  /* implicit */ operator StateEnum() const { return state; }

  bool isUnavailable() const { return state == kUnavailable; }
  bool isAvailable() const { return state == kAvailable; }
  bool isError() const { return state == kError; }
  // True once the state is terminal (available or error).
  bool isAvailableOrError() const { return isAvailable() || isError(); }

  // Returns a human readable name of the state for debugging output.
  const char *debug() const {
    switch (state) {
    case kUnavailable:
      return "unavailable";
    case kAvailable:
      return "available";
    case kError:
      return "error";
    }
    // The enum has a fixed underlying type, so a value outside of the listed
    // enumerators can still be stored. Return a sentinel instead of flowing
    // off the end of a non-void function, which is undefined behavior.
    return "unknown";
  }

private:
  StateEnum state;
};
11739957aa4SEugene Zhulenev
11839957aa4SEugene Zhulenev // -------------------------------------------------------------------------- //
119a86a9b5eSEugene Zhulenev // A base class for all reference counted objects created by the async runtime.
120a86a9b5eSEugene Zhulenev // -------------------------------------------------------------------------- //
121a86a9b5eSEugene Zhulenev
122a86a9b5eSEugene Zhulenev class RefCounted {
123a86a9b5eSEugene Zhulenev public:
RefCounted(AsyncRuntime * runtime,int64_t refCount=1)12492db09cdSEugene Zhulenev RefCounted(AsyncRuntime *runtime, int64_t refCount = 1)
125a86a9b5eSEugene Zhulenev : runtime(runtime), refCount(refCount) {
126a86a9b5eSEugene Zhulenev runtime->addNumRefCountedObjects();
127a86a9b5eSEugene Zhulenev }
128a86a9b5eSEugene Zhulenev
~RefCounted()129a86a9b5eSEugene Zhulenev virtual ~RefCounted() {
130a86a9b5eSEugene Zhulenev assert(refCount.load() == 0 && "reference count must be zero");
131a86a9b5eSEugene Zhulenev runtime->dropNumRefCountedObjects();
132a86a9b5eSEugene Zhulenev }
133a86a9b5eSEugene Zhulenev
134a86a9b5eSEugene Zhulenev RefCounted(const RefCounted &) = delete;
135a86a9b5eSEugene Zhulenev RefCounted &operator=(const RefCounted &) = delete;
136a86a9b5eSEugene Zhulenev
addRef(int64_t count=1)13792db09cdSEugene Zhulenev void addRef(int64_t count = 1) { refCount.fetch_add(count); }
138a86a9b5eSEugene Zhulenev
dropRef(int64_t count=1)13992db09cdSEugene Zhulenev void dropRef(int64_t count = 1) {
14092db09cdSEugene Zhulenev int64_t previous = refCount.fetch_sub(count);
141a86a9b5eSEugene Zhulenev assert(previous >= count && "reference count should not go below zero");
142a86a9b5eSEugene Zhulenev if (previous == count)
143a86a9b5eSEugene Zhulenev destroy();
144a86a9b5eSEugene Zhulenev }
145a86a9b5eSEugene Zhulenev
146a86a9b5eSEugene Zhulenev protected:
destroy()147a86a9b5eSEugene Zhulenev virtual void destroy() { delete this; }
148a86a9b5eSEugene Zhulenev
149a86a9b5eSEugene Zhulenev private:
150a86a9b5eSEugene Zhulenev AsyncRuntime *runtime;
15192db09cdSEugene Zhulenev std::atomic<int64_t> refCount;
152a86a9b5eSEugene Zhulenev };
153a86a9b5eSEugene Zhulenev
154a86a9b5eSEugene Zhulenev } // namespace
155a86a9b5eSEugene Zhulenev
15611f1027bSEugene Zhulenev // Returns the default per-process instance of an async runtime.
getDefaultAsyncRuntimeInstance()1571fc98642SEugene Zhulenev static std::unique_ptr<AsyncRuntime> &getDefaultAsyncRuntimeInstance() {
15811f1027bSEugene Zhulenev static auto runtime = std::make_unique<AsyncRuntime>();
1591fc98642SEugene Zhulenev return runtime;
1601fc98642SEugene Zhulenev }
1611fc98642SEugene Zhulenev
resetDefaultAsyncRuntime()1621fc98642SEugene Zhulenev static void resetDefaultAsyncRuntime() {
1631fc98642SEugene Zhulenev return getDefaultAsyncRuntimeInstance().reset();
1641fc98642SEugene Zhulenev }
1651fc98642SEugene Zhulenev
getDefaultAsyncRuntime()1661fc98642SEugene Zhulenev static AsyncRuntime *getDefaultAsyncRuntime() {
1671fc98642SEugene Zhulenev return getDefaultAsyncRuntimeInstance().get();
16811f1027bSEugene Zhulenev }
16911f1027bSEugene Zhulenev
170621ad468SEugene Zhulenev // Async token provides a mechanism to signal asynchronous operation completion.
171a86a9b5eSEugene Zhulenev struct AsyncToken : public RefCounted {
172a86a9b5eSEugene Zhulenev // AsyncToken created with a reference count of 2 because it will be returned
173a86a9b5eSEugene Zhulenev // to the `async.execute` caller and also will be later on emplaced by the
174a86a9b5eSEugene Zhulenev // asynchronously executed task. If the caller immediately will drop its
175a86a9b5eSEugene Zhulenev // reference we must ensure that the token will be alive until the
176a86a9b5eSEugene Zhulenev // asynchronous operation is completed.
AsyncTokenmlir::runtime::AsyncToken177a2223b09SEugene Zhulenev AsyncToken(AsyncRuntime *runtime)
17839957aa4SEugene Zhulenev : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable) {}
179a86a9b5eSEugene Zhulenev
18039957aa4SEugene Zhulenev std::atomic<State::StateEnum> state;
181a2223b09SEugene Zhulenev
182a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
18336ce915aSLei Zhang std::mutex mu;
18436ce915aSLei Zhang std::condition_variable cv;
18536ce915aSLei Zhang std::vector<std::function<void()>> awaiters;
18636ce915aSLei Zhang };
18736ce915aSLei Zhang
188621ad468SEugene Zhulenev // Async value provides a mechanism to access the result of asynchronous
189621ad468SEugene Zhulenev // operations. It owns the storage that is used to store/load the value of the
190621ad468SEugene Zhulenev // underlying type, and a flag to signal if the value is ready or not.
191621ad468SEugene Zhulenev struct AsyncValue : public RefCounted {
192621ad468SEugene Zhulenev // AsyncValue similar to an AsyncToken created with a reference count of 2.
AsyncValuemlir::runtime::AsyncValue19392db09cdSEugene Zhulenev AsyncValue(AsyncRuntime *runtime, int64_t size)
19439957aa4SEugene Zhulenev : RefCounted(runtime, /*refCount=*/2), state(State::kUnavailable),
19539957aa4SEugene Zhulenev storage(size) {}
196621ad468SEugene Zhulenev
19739957aa4SEugene Zhulenev std::atomic<State::StateEnum> state;
198621ad468SEugene Zhulenev
199621ad468SEugene Zhulenev // Use vector of bytes to store async value payload.
2004109276fSyijiagu std::vector<std::byte> storage;
201a2223b09SEugene Zhulenev
202a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
203a2223b09SEugene Zhulenev std::mutex mu;
204a2223b09SEugene Zhulenev std::condition_variable cv;
205a2223b09SEugene Zhulenev std::vector<std::function<void()>> awaiters;
206621ad468SEugene Zhulenev };
207621ad468SEugene Zhulenev
208621ad468SEugene Zhulenev // Async group provides a mechanism to group together multiple async tokens or
209621ad468SEugene Zhulenev // values to await on all of them together (wait for the completion of all
210621ad468SEugene Zhulenev // tokens or values added to the group).
211a86a9b5eSEugene Zhulenev struct AsyncGroup : public RefCounted {
AsyncGroupmlir::runtime::AsyncGroup212d43b2360SEugene Zhulenev AsyncGroup(AsyncRuntime *runtime, int64_t size)
213d43b2360SEugene Zhulenev : RefCounted(runtime), pendingTokens(size), numErrors(0), rank(0) {}
214a86a9b5eSEugene Zhulenev
215a86a9b5eSEugene Zhulenev std::atomic<int> pendingTokens;
216d8c84d2aSEugene Zhulenev std::atomic<int> numErrors;
217a86a9b5eSEugene Zhulenev std::atomic<int> rank;
218a86a9b5eSEugene Zhulenev
219a2223b09SEugene Zhulenev // Pending awaiters are guarded by a mutex.
220c30ab6c2SEugene Zhulenev std::mutex mu;
221c30ab6c2SEugene Zhulenev std::condition_variable cv;
222c30ab6c2SEugene Zhulenev std::vector<std::function<void()>> awaiters;
223c30ab6c2SEugene Zhulenev };
224c30ab6c2SEugene Zhulenev
225a86a9b5eSEugene Zhulenev // Adds references to reference counted runtime object.
mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr,int64_t count)22692db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeAddRef(RefCountedObjPtr ptr, int64_t count) {
227a86a9b5eSEugene Zhulenev RefCounted *refCounted = static_cast<RefCounted *>(ptr);
228a86a9b5eSEugene Zhulenev refCounted->addRef(count);
229a86a9b5eSEugene Zhulenev }
230a86a9b5eSEugene Zhulenev
231a86a9b5eSEugene Zhulenev // Drops references from reference counted runtime object.
mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr,int64_t count)23292db09cdSEugene Zhulenev extern "C" void mlirAsyncRuntimeDropRef(RefCountedObjPtr ptr, int64_t count) {
233a86a9b5eSEugene Zhulenev RefCounted *refCounted = static_cast<RefCounted *>(ptr);
234a86a9b5eSEugene Zhulenev refCounted->dropRef(count);
235a86a9b5eSEugene Zhulenev }
236a86a9b5eSEugene Zhulenev
237621ad468SEugene Zhulenev // Creates a new `async.token` in not-ready state.
mlirAsyncRuntimeCreateToken()2386fd9e59eSPaul Lietar extern "C" AsyncToken *mlirAsyncRuntimeCreateToken() {
2391fc98642SEugene Zhulenev AsyncToken *token = new AsyncToken(getDefaultAsyncRuntime());
24036ce915aSLei Zhang return token;
24136ce915aSLei Zhang }
24236ce915aSLei Zhang
243621ad468SEugene Zhulenev // Creates a new `async.value` in not-ready state.
mlirAsyncRuntimeCreateValue(int64_t size)24492db09cdSEugene Zhulenev extern "C" AsyncValue *mlirAsyncRuntimeCreateValue(int64_t size) {
2451fc98642SEugene Zhulenev AsyncValue *value = new AsyncValue(getDefaultAsyncRuntime(), size);
246621ad468SEugene Zhulenev return value;
247621ad468SEugene Zhulenev }
248621ad468SEugene Zhulenev
249c30ab6c2SEugene Zhulenev // Create a new `async.group` in empty state.
mlirAsyncRuntimeCreateGroup(int64_t size)250d43b2360SEugene Zhulenev extern "C" AsyncGroup *mlirAsyncRuntimeCreateGroup(int64_t size) {
251d43b2360SEugene Zhulenev AsyncGroup *group = new AsyncGroup(getDefaultAsyncRuntime(), size);
252c30ab6c2SEugene Zhulenev return group;
253c30ab6c2SEugene Zhulenev }
254c30ab6c2SEugene Zhulenev
mlirAsyncRuntimeAddTokenToGroup(AsyncToken * token,AsyncGroup * group)2553d95d1b4SEugene Zhulenev extern "C" int64_t mlirAsyncRuntimeAddTokenToGroup(AsyncToken *token,
2563d95d1b4SEugene Zhulenev AsyncGroup *group) {
257c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lockToken(token->mu);
258c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lockGroup(group->mu);
259c30ab6c2SEugene Zhulenev
260a86a9b5eSEugene Zhulenev // Get the rank of the token inside the group before we drop the reference.
261a86a9b5eSEugene Zhulenev int rank = group->rank.fetch_add(1);
262c30ab6c2SEugene Zhulenev
263d8c84d2aSEugene Zhulenev auto onTokenReady = [group, token]() {
264d8c84d2aSEugene Zhulenev // Increment the number of errors in the group.
265d8c84d2aSEugene Zhulenev if (State(token->state).isError())
266d8c84d2aSEugene Zhulenev group->numErrors.fetch_add(1);
267d8c84d2aSEugene Zhulenev
268d43b2360SEugene Zhulenev // If pending tokens go below zero it means that more tokens than the group
269d43b2360SEugene Zhulenev // size were added to this group.
270d43b2360SEugene Zhulenev assert(group->pendingTokens > 0 && "wrong group size");
271d43b2360SEugene Zhulenev
272c30ab6c2SEugene Zhulenev // Run all group awaiters if it was the last token in the group.
273c30ab6c2SEugene Zhulenev if (group->pendingTokens.fetch_sub(1) == 1) {
274c30ab6c2SEugene Zhulenev group->cv.notify_all();
275c30ab6c2SEugene Zhulenev for (auto &awaiter : group->awaiters)
276c30ab6c2SEugene Zhulenev awaiter();
277c30ab6c2SEugene Zhulenev }
278c30ab6c2SEugene Zhulenev };
279c30ab6c2SEugene Zhulenev
28039957aa4SEugene Zhulenev if (State(token->state).isAvailableOrError()) {
2813d95d1b4SEugene Zhulenev // Update group pending tokens immediately and maybe run awaiters.
2823d95d1b4SEugene Zhulenev onTokenReady();
2833d95d1b4SEugene Zhulenev
284a86a9b5eSEugene Zhulenev } else {
2853d95d1b4SEugene Zhulenev // Update group pending tokens when token will become ready. Because this
2863d95d1b4SEugene Zhulenev // will happen asynchronously we must ensure that `group` is alive until
2873d95d1b4SEugene Zhulenev // then, and re-ackquire the lock.
288a86a9b5eSEugene Zhulenev group->addRef();
2893d95d1b4SEugene Zhulenev
290e5639b3fSMehdi Amini token->awaiters.emplace_back([group, onTokenReady]() {
2913d95d1b4SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
2923d95d1b4SEugene Zhulenev {
2933d95d1b4SEugene Zhulenev std::unique_lock<std::mutex> lockGroup(group->mu);
2943d95d1b4SEugene Zhulenev onTokenReady();
2953d95d1b4SEugene Zhulenev }
2963d95d1b4SEugene Zhulenev group->dropRef();
2973d95d1b4SEugene Zhulenev });
298a86a9b5eSEugene Zhulenev }
299c30ab6c2SEugene Zhulenev
300a86a9b5eSEugene Zhulenev return rank;
301c30ab6c2SEugene Zhulenev }
302c30ab6c2SEugene Zhulenev
30339957aa4SEugene Zhulenev // Switches `async.token` to available or error state (terminatl state) and runs
30439957aa4SEugene Zhulenev // all awaiters.
setTokenState(AsyncToken * token,State state)30539957aa4SEugene Zhulenev static void setTokenState(AsyncToken *token, State state) {
30639957aa4SEugene Zhulenev assert(state.isAvailableOrError() && "must be terminal state");
30739957aa4SEugene Zhulenev assert(State(token->state).isUnavailable() && "token must be unavailable");
30839957aa4SEugene Zhulenev
3093d95d1b4SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
3103d95d1b4SEugene Zhulenev {
31136ce915aSLei Zhang std::unique_lock<std::mutex> lock(token->mu);
31239957aa4SEugene Zhulenev token->state = state;
31336ce915aSLei Zhang token->cv.notify_all();
31436ce915aSLei Zhang for (auto &awaiter : token->awaiters)
31536ce915aSLei Zhang awaiter();
3163d95d1b4SEugene Zhulenev }
317a86a9b5eSEugene Zhulenev
318a86a9b5eSEugene Zhulenev // Async tokens created with a ref count `2` to keep token alive until the
319a86a9b5eSEugene Zhulenev // async task completes. Drop this reference explicitly when token emplaced.
320a86a9b5eSEugene Zhulenev token->dropRef();
32136ce915aSLei Zhang }
32236ce915aSLei Zhang
setValueState(AsyncValue * value,State state)32339957aa4SEugene Zhulenev static void setValueState(AsyncValue *value, State state) {
32439957aa4SEugene Zhulenev assert(state.isAvailableOrError() && "must be terminal state");
32539957aa4SEugene Zhulenev assert(State(value->state).isUnavailable() && "value must be unavailable");
32639957aa4SEugene Zhulenev
327621ad468SEugene Zhulenev // Make sure that `dropRef` does not destroy the mutex owned by the lock.
328621ad468SEugene Zhulenev {
329621ad468SEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
33039957aa4SEugene Zhulenev value->state = state;
331621ad468SEugene Zhulenev value->cv.notify_all();
332621ad468SEugene Zhulenev for (auto &awaiter : value->awaiters)
333621ad468SEugene Zhulenev awaiter();
334621ad468SEugene Zhulenev }
335621ad468SEugene Zhulenev
336621ad468SEugene Zhulenev // Async values created with a ref count `2` to keep value alive until the
337621ad468SEugene Zhulenev // async task completes. Drop this reference explicitly when value emplaced.
338621ad468SEugene Zhulenev value->dropRef();
339621ad468SEugene Zhulenev }
340621ad468SEugene Zhulenev
mlirAsyncRuntimeEmplaceToken(AsyncToken * token)34139957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceToken(AsyncToken *token) {
34239957aa4SEugene Zhulenev setTokenState(token, State::kAvailable);
34339957aa4SEugene Zhulenev }
34439957aa4SEugene Zhulenev
mlirAsyncRuntimeEmplaceValue(AsyncValue * value)34539957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeEmplaceValue(AsyncValue *value) {
34639957aa4SEugene Zhulenev setValueState(value, State::kAvailable);
34739957aa4SEugene Zhulenev }
34839957aa4SEugene Zhulenev
mlirAsyncRuntimeSetTokenError(AsyncToken * token)34939957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetTokenError(AsyncToken *token) {
35039957aa4SEugene Zhulenev setTokenState(token, State::kError);
35139957aa4SEugene Zhulenev }
35239957aa4SEugene Zhulenev
mlirAsyncRuntimeSetValueError(AsyncValue * value)35339957aa4SEugene Zhulenev extern "C" void mlirAsyncRuntimeSetValueError(AsyncValue *value) {
35439957aa4SEugene Zhulenev setValueState(value, State::kError);
35539957aa4SEugene Zhulenev }
35639957aa4SEugene Zhulenev
mlirAsyncRuntimeIsTokenError(AsyncToken * token)35739957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsTokenError(AsyncToken *token) {
35839957aa4SEugene Zhulenev return State(token->state).isError();
35939957aa4SEugene Zhulenev }
36039957aa4SEugene Zhulenev
mlirAsyncRuntimeIsValueError(AsyncValue * value)36139957aa4SEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsValueError(AsyncValue *value) {
36239957aa4SEugene Zhulenev return State(value->state).isError();
36339957aa4SEugene Zhulenev }
36439957aa4SEugene Zhulenev
mlirAsyncRuntimeIsGroupError(AsyncGroup * group)365d8c84d2aSEugene Zhulenev extern "C" bool mlirAsyncRuntimeIsGroupError(AsyncGroup *group) {
366d8c84d2aSEugene Zhulenev return group->numErrors.load() > 0;
367d8c84d2aSEugene Zhulenev }
368d8c84d2aSEugene Zhulenev
mlirAsyncRuntimeAwaitToken(AsyncToken * token)3696fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitToken(AsyncToken *token) {
37036ce915aSLei Zhang std::unique_lock<std::mutex> lock(token->mu);
37139957aa4SEugene Zhulenev if (!State(token->state).isAvailableOrError())
37239957aa4SEugene Zhulenev token->cv.wait(
37339957aa4SEugene Zhulenev lock, [token] { return State(token->state).isAvailableOrError(); });
374c30ab6c2SEugene Zhulenev }
375c30ab6c2SEugene Zhulenev
mlirAsyncRuntimeAwaitValue(AsyncValue * value)376621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValue(AsyncValue *value) {
377621ad468SEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
37839957aa4SEugene Zhulenev if (!State(value->state).isAvailableOrError())
37939957aa4SEugene Zhulenev value->cv.wait(
38039957aa4SEugene Zhulenev lock, [value] { return State(value->state).isAvailableOrError(); });
381621ad468SEugene Zhulenev }
382621ad468SEugene Zhulenev
mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup * group)3833d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroup(AsyncGroup *group) {
384c30ab6c2SEugene Zhulenev std::unique_lock<std::mutex> lock(group->mu);
385c30ab6c2SEugene Zhulenev if (group->pendingTokens != 0)
386c30ab6c2SEugene Zhulenev group->cv.wait(lock, [group] { return group->pendingTokens == 0; });
38736ce915aSLei Zhang }
38836ce915aSLei Zhang
389621ad468SEugene Zhulenev // Returns a pointer to the storage owned by the async value.
mlirAsyncRuntimeGetValueStorage(AsyncValue * value)390621ad468SEugene Zhulenev extern "C" ValueStorage mlirAsyncRuntimeGetValueStorage(AsyncValue *value) {
39139957aa4SEugene Zhulenev assert(!State(value->state).isError() && "unexpected error state");
392621ad468SEugene Zhulenev return value->storage.data();
393621ad468SEugene Zhulenev }
394621ad468SEugene Zhulenev
mlirAsyncRuntimeExecute(CoroHandle handle,CoroResume resume)3956fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeExecute(CoroHandle handle, CoroResume resume) {
396bb0e6213SEugene Zhulenev auto *runtime = getDefaultAsyncRuntime();
397bb0e6213SEugene Zhulenev runtime->getThreadPool().async([handle, resume]() { (*resume)(handle); });
39836ce915aSLei Zhang }
39936ce915aSLei Zhang
mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken * token,CoroHandle handle,CoroResume resume)4006fd9e59eSPaul Lietar extern "C" void mlirAsyncRuntimeAwaitTokenAndExecute(AsyncToken *token,
4016fd9e59eSPaul Lietar CoroHandle handle,
40236ce915aSLei Zhang CoroResume resume) {
4033d95d1b4SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
404f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(token->mu);
40539957aa4SEugene Zhulenev if (State(token->state).isAvailableOrError()) {
406f63f28edSEugene Zhulenev lock.unlock();
4073d95d1b4SEugene Zhulenev execute();
408a2223b09SEugene Zhulenev } else {
409e5639b3fSMehdi Amini token->awaiters.emplace_back([execute]() { execute(); });
41036ce915aSLei Zhang }
411a2223b09SEugene Zhulenev }
41236ce915aSLei Zhang
mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue * value,CoroHandle handle,CoroResume resume)413621ad468SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitValueAndExecute(AsyncValue *value,
414621ad468SEugene Zhulenev CoroHandle handle,
415621ad468SEugene Zhulenev CoroResume resume) {
416621ad468SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
417f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(value->mu);
41839957aa4SEugene Zhulenev if (State(value->state).isAvailableOrError()) {
419f63f28edSEugene Zhulenev lock.unlock();
420621ad468SEugene Zhulenev execute();
421a2223b09SEugene Zhulenev } else {
422e5639b3fSMehdi Amini value->awaiters.emplace_back([execute]() { execute(); });
423621ad468SEugene Zhulenev }
424a2223b09SEugene Zhulenev }
425621ad468SEugene Zhulenev
mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup * group,CoroHandle handle,CoroResume resume)4263d95d1b4SEugene Zhulenev extern "C" void mlirAsyncRuntimeAwaitAllInGroupAndExecute(AsyncGroup *group,
4273d95d1b4SEugene Zhulenev CoroHandle handle,
428c30ab6c2SEugene Zhulenev CoroResume resume) {
4293d95d1b4SEugene Zhulenev auto execute = [handle, resume]() { (*resume)(handle); };
430f63f28edSEugene Zhulenev std::unique_lock<std::mutex> lock(group->mu);
431a2223b09SEugene Zhulenev if (group->pendingTokens == 0) {
432f63f28edSEugene Zhulenev lock.unlock();
4333d95d1b4SEugene Zhulenev execute();
434a2223b09SEugene Zhulenev } else {
435e5639b3fSMehdi Amini group->awaiters.emplace_back([execute]() { execute(); });
436c30ab6c2SEugene Zhulenev }
437a2223b09SEugene Zhulenev }
438c30ab6c2SEugene Zhulenev
mlirAsyncRuntimGetNumWorkerThreads()439149311b4Sbakhtiyar extern "C" int64_t mlirAsyncRuntimGetNumWorkerThreads() {
440744616b3SMehdi Amini return getDefaultAsyncRuntime()->getThreadPool().getMaxConcurrency();
441149311b4Sbakhtiyar }
442149311b4Sbakhtiyar
44336ce915aSLei Zhang //===----------------------------------------------------------------------===//
44436ce915aSLei Zhang // Small async runtime support library for testing.
44536ce915aSLei Zhang //===----------------------------------------------------------------------===//
44636ce915aSLei Zhang
// Prints the id of the calling thread to stdout.
extern "C" void mlirAsyncRuntimePrintCurrentThreadId() {
  // The id is looked up once per thread and cached in a thread-local.
  static thread_local std::thread::id cachedId = std::this_thread::get_id();
  std::cout << "Current thread id: " << cachedId << '\n';
}
45136ce915aSLei Zhang
4521fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
453530db6a3SIngo Müller // MLIR ExecutionEngine dynamic library integration.
4541fc98642SEugene Zhulenev //===----------------------------------------------------------------------===//
4551fc98642SEugene Zhulenev
456dd2dac2fSMatthew Parkinson // Visual Studio had a bug that fails to compile nested generic lambdas
457dd2dac2fSMatthew Parkinson // inside an `extern "C"` function.
458dd2dac2fSMatthew Parkinson // https://developercommunity.visualstudio.com/content/problem/475494/clexe-error-with-lambda-inside-function-templates.html
459dd2dac2fSMatthew Parkinson // The bug is fixed in VS2019 16.1. Separating the declaration and definition is
460dd2dac2fSMatthew Parkinson // a work around for older versions of Visual Studio.
46102b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
462f9bce19eSIngo Müller extern "C" MLIR_ASYNC_RUNTIME_EXPORT void
4630b3841ebSIngo Müller __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols);
464dd2dac2fSMatthew Parkinson
46502b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_init(llvm::StringMap<void * > & exportSymbols)4660b3841ebSIngo Müller void __mlir_execution_engine_init(llvm::StringMap<void *> &exportSymbols) {
4671fc98642SEugene Zhulenev auto exportSymbol = [&](llvm::StringRef name, auto ptr) {
4681fc98642SEugene Zhulenev assert(exportSymbols.count(name) == 0 && "symbol already exists");
4691fc98642SEugene Zhulenev exportSymbols[name] = reinterpret_cast<void *>(ptr);
4701fc98642SEugene Zhulenev };
4711fc98642SEugene Zhulenev
4721fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAddRef",
4731fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAddRef);
4741fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeDropRef",
4751fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeDropRef);
4761fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeExecute",
4771fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeExecute);
4781fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeGetValueStorage",
4791fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeGetValueStorage);
4801fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateToken",
4811fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateToken);
4821fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateValue",
4831fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateValue);
4841fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeEmplaceToken",
4851fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeEmplaceToken);
4861fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeEmplaceValue",
4871fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeEmplaceValue);
48839957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeSetTokenError",
48939957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeSetTokenError);
49039957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeSetValueError",
49139957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeSetValueError);
49239957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsTokenError",
49339957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsTokenError);
49439957aa4SEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsValueError",
49539957aa4SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsValueError);
496d8c84d2aSEugene Zhulenev exportSymbol("mlirAsyncRuntimeIsGroupError",
497d8c84d2aSEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeIsGroupError);
4981fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitToken",
4991fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitToken);
5001fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitValue",
5011fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitValue);
5021fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitTokenAndExecute",
5031fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitTokenAndExecute);
5041fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitValueAndExecute",
5051fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitValueAndExecute);
5061fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeCreateGroup",
5071fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeCreateGroup);
5081fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAddTokenToGroup",
5091fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAddTokenToGroup);
5101fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitAllInGroup",
5111fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroup);
5121fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimeAwaitAllInGroupAndExecute",
5131fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimeAwaitAllInGroupAndExecute);
514149311b4Sbakhtiyar exportSymbol("mlirAsyncRuntimGetNumWorkerThreads",
515149311b4Sbakhtiyar &mlir::runtime::mlirAsyncRuntimGetNumWorkerThreads);
5161fc98642SEugene Zhulenev exportSymbol("mlirAsyncRuntimePrintCurrentThreadId",
5171fc98642SEugene Zhulenev &mlir::runtime::mlirAsyncRuntimePrintCurrentThreadId);
5181fc98642SEugene Zhulenev }
5191fc98642SEugene Zhulenev
52002b6fb21SMehdi Amini // NOLINTNEXTLINE(*-identifier-naming): externally called.
__mlir_execution_engine_destroy()521f9bce19eSIngo Müller extern "C" MLIR_ASYNC_RUNTIME_EXPORT void __mlir_execution_engine_destroy() {
5220b3841ebSIngo Müller resetDefaultAsyncRuntime();
5230b3841ebSIngo Müller }
5241fc98642SEugene Zhulenev
525dd2dac2fSMatthew Parkinson } // namespace runtime
526dd2dac2fSMatthew Parkinson } // namespace mlir
527