1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 // UNSUPPORTED: c++03, c++11, c++14, c++17 10 11 #include <cassert> 12 #include <coroutine> 13 #include <memory> 14 15 #include "test_macros.h" 16 17 bool cancel = false; 18 19 struct goroutine 20 { 21 static int const N = 10; 22 static int count; 23 static std::coroutine_handle<> stack[N]; 24 schedulegoroutine25 static void schedule(std::coroutine_handle<>& rh) 26 { 27 assert(count < N); 28 stack[count++] = rh; 29 rh = nullptr; 30 } 31 32 goroutine() = default; 33 goroutine(const goroutine&) = default; 34 goroutine& operator=(const goroutine&) = default; ~goroutinegoroutine35 ~goroutine() {} 36 run_onegoroutine37 static void run_one() 38 { 39 assert(count > 0); 40 stack[--count](); 41 } 42 43 struct promise_type 44 { initial_suspendgoroutine::promise_type45 std::suspend_never initial_suspend() { 46 return {}; 47 } final_suspendgoroutine::promise_type48 std::suspend_never final_suspend() noexcept { return {}; } return_voidgoroutine::promise_type49 void return_void() {} get_return_objectgoroutine::promise_type50 goroutine get_return_object() { 51 return{}; 52 } unhandled_exceptiongoroutine::promise_type53 void unhandled_exception() {} 54 }; 55 }; 56 int goroutine::count; 57 std::coroutine_handle<> goroutine::stack[N]; 58 59 std::coroutine_handle<goroutine::promise_type> workaround; 60 61 class channel; 62 63 struct push_awaiter { 64 channel* ch; await_readypush_awaiter65 bool await_ready() {return false; } 66 void await_suspend(std::coroutine_handle<> rh); await_resumepush_awaiter67 void await_resume() {} 68 }; 69 70 struct pull_awaiter { 71 channel * ch; 72 73 bool await_ready(); 74 void await_suspend(std::coroutine_handle<> rh); 75 int await_resume(); 76 }; 77 78 class channel 79 { 80 using T = int; 81 82 friend struct push_awaiter; 83 friend struct pull_awaiter; 84 85 T const* pvalue = nullptr; 86 std::coroutine_handle<> reader = nullptr; 87 std::coroutine_handle<> writer = nullptr; 88 public: push(T const & value)89 push_awaiter push(T const& value) 90 { 91 assert(pvalue == nullptr); 92 assert(!writer); 93 pvalue = &value; 94 95 return { this }; 96 } 97 pull()98 pull_awaiter pull() 99 { 100 assert(!reader); 101 102 return { this }; 103 } 104 sync_push(T const & value)105 void sync_push(T const& value) 106 { 107 assert(!pvalue); 108 pvalue = &value; 109 assert(reader); 110 reader(); 111 assert(!pvalue); 112 reader = nullptr; 113 } 114 sync_pull()115 auto sync_pull() 116 { 117 while (!pvalue) goroutine::run_one(); 118 auto result = *pvalue; 119 pvalue = nullptr; 120 if (writer) 121 { 122 auto wr = writer; 123 writer = nullptr; 124 wr(); 125 } 126 return result; 127 } 128 }; 129 await_suspend(std::coroutine_handle<> rh)130void push_awaiter::await_suspend(std::coroutine_handle<> rh) 131 { 132 ch->writer = rh; 133 if (ch->reader) goroutine::schedule(ch->reader); 134 } 135 136 await_ready()137bool pull_awaiter::await_ready() { 138 return !!ch->writer; 139 } await_suspend(std::coroutine_handle<> rh)140void pull_awaiter::await_suspend(std::coroutine_handle<> rh) { 141 ch->reader = rh; 142 } await_resume()143int pull_awaiter::await_resume() { 144 auto result = *ch->pvalue; 145 ch->pvalue = nullptr; 146 if (ch->writer) { 147 //goroutine::schedule(ch->writer); 148 auto wr = ch->writer; 149 ch->writer = nullptr; 150 wr(); 151 } 152 return result; 153 } 154 pusher(channel & left,channel & right)155goroutine pusher(channel& left, channel& right) 156 { 157 for (;;) { 158 auto val = co_await left.pull(); 159 co_await right.push(val + 1); 160 } 161 } 162 163 const int N = 100; 164 channel c[N + 1]; 165 main(int,char **)166int main(int, char**) { 167 for (int i = 0; i < N; ++i) 168 pusher(c[i], c[i + 1]); 169 170 c[0].sync_push(0); 171 int result = c[N].sync_pull(); 172 173 assert(result == 100); 174 175 return 0; 176 } 177