xref: /llvm-project/libcxx/test/std/language.support/support.coroutines/end.to.end/go.pass.cpp (revision 70248920fcd804a5825ecf69f24b96a7e340afe6)
1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // UNSUPPORTED: c++03, c++11, c++14, c++17
10 
11 #include <cassert>
12 #include <coroutine>
13 #include <memory>
14 
15 #include "test_macros.h"
16 
// Global flag, initially false.
// NOTE(review): nothing in the visible source reads or writes this variable.
bool cancel{false};
18 
19 struct goroutine
20 {
21   static int const N = 10;
22   static int count;
23   static std::coroutine_handle<> stack[N];
24 
schedulegoroutine25   static void schedule(std::coroutine_handle<>& rh)
26   {
27     assert(count < N);
28     stack[count++] = rh;
29     rh = nullptr;
30   }
31 
32   goroutine() = default;
33   goroutine(const goroutine&) = default;
34   goroutine& operator=(const goroutine&) = default;
~goroutinegoroutine35   ~goroutine() {}
36 
run_onegoroutine37   static void run_one()
38   {
39     assert(count > 0);
40     stack[--count]();
41   }
42 
43   struct promise_type
44   {
initial_suspendgoroutine::promise_type45     std::suspend_never initial_suspend() {
46       return {};
47     }
final_suspendgoroutine::promise_type48     std::suspend_never final_suspend() noexcept { return {}; }
return_voidgoroutine::promise_type49     void return_void() {}
get_return_objectgoroutine::promise_type50     goroutine get_return_object() {
51       return{};
52     }
unhandled_exceptiongoroutine::promise_type53     void unhandled_exception() {}
54   };
55 };
56 int goroutine::count;
57 std::coroutine_handle<> goroutine::stack[N];
58 
59 std::coroutine_handle<goroutine::promise_type> workaround;
60 
61 class channel;
62 
63 struct push_awaiter {
64   channel* ch;
await_readypush_awaiter65   bool await_ready() {return false; }
66   void await_suspend(std::coroutine_handle<> rh);
await_resumepush_awaiter67   void await_resume() {}
68 };
69 
70 struct pull_awaiter {
71   channel * ch;
72 
73   bool await_ready();
74   void await_suspend(std::coroutine_handle<> rh);
75   int await_resume();
76 };
77 
78 class channel
79 {
80   using T = int;
81 
82   friend struct push_awaiter;
83   friend struct pull_awaiter;
84 
85   T const* pvalue = nullptr;
86   std::coroutine_handle<> reader = nullptr;
87   std::coroutine_handle<> writer = nullptr;
88 public:
push(T const & value)89   push_awaiter push(T const& value)
90   {
91     assert(pvalue == nullptr);
92     assert(!writer);
93     pvalue = &value;
94 
95     return { this };
96   }
97 
pull()98   pull_awaiter pull()
99   {
100     assert(!reader);
101 
102     return { this };
103   }
104 
sync_push(T const & value)105   void sync_push(T const& value)
106   {
107     assert(!pvalue);
108     pvalue = &value;
109     assert(reader);
110     reader();
111     assert(!pvalue);
112     reader = nullptr;
113   }
114 
sync_pull()115   auto sync_pull()
116   {
117     while (!pvalue) goroutine::run_one();
118     auto result = *pvalue;
119     pvalue = nullptr;
120     if (writer)
121     {
122       auto wr = writer;
123       writer = nullptr;
124       wr();
125     }
126     return result;
127   }
128 };
129 
await_suspend(std::coroutine_handle<> rh)130 void push_awaiter::await_suspend(std::coroutine_handle<> rh)
131 {
132   ch->writer = rh;
133   if (ch->reader) goroutine::schedule(ch->reader);
134 }
135 
136 
await_ready()137 bool pull_awaiter::await_ready() {
138   return !!ch->writer;
139 }
await_suspend(std::coroutine_handle<> rh)140 void pull_awaiter::await_suspend(std::coroutine_handle<> rh) {
141   ch->reader = rh;
142 }
await_resume()143 int pull_awaiter::await_resume() {
144   auto result = *ch->pvalue;
145   ch->pvalue = nullptr;
146   if (ch->writer) {
147     //goroutine::schedule(ch->writer);
148     auto wr = ch->writer;
149     ch->writer = nullptr;
150     wr();
151   }
152   return result;
153 }
154 
pusher(channel & left,channel & right)155 goroutine pusher(channel& left, channel& right)
156 {
157   for (;;) {
158     auto val = co_await left.pull();
159     co_await right.push(val + 1);
160   }
161 }
162 
163 const int N = 100;
164 channel c[N + 1];
165 
main(int,char **)166 int main(int, char**) {
167   for (int i = 0; i < N; ++i)
168     pusher(c[i], c[i + 1]);
169 
170   c[0].sync_push(0);
171   int result = c[N].sync_pull();
172 
173   assert(result == 100);
174 
175   return 0;
176 }
177