Lightweight 0.20260617.0
Loading...
Searching...
No Matches
SyncWait.hpp
1// SPDX-License-Identifier: Apache-2.0
2#pragma once
3
4#include "Task.hpp"
5
6#include <atomic>
7#include <functional>
8#include <semaphore>
9#include <type_traits>
10
11namespace Lightweight::Async
12{
13
14namespace detail
15{
16
17 /// One-shot completion signal used to bridge a coroutine back to a blocking caller.
18 ///
19 /// Supports both a blocking wait (@ref Wait, used by @ref SyncWait) and a pollable
20 /// flag (@ref IsSet, used by @ref SyncWaitPumping which drives an executor instead of
21 /// blocking).
22 class SyncWaitEvent
23 {
24 public:
25 /// Installs a waker invoked from @ref Set after the completion flag is published.
26 ///
27 /// @ref SyncWaitPumping uses this so a completion that lands on a thread other than the
28 /// pumping thread still nudges the pumped executor to re-check its predicate, instead of
29 /// sleeping forever on a flag that was set without notifying the executor's condition
30 /// variable. Install it before the driver is started.
31 void SetWaker(std::function<void()> waker)
32 {
33 _waker = std::move(waker);
34 }
35
36 void Set() noexcept
37 {
38 _set.store(true, std::memory_order_release);
39 if (_waker)
40 _waker(); // publish the flag (above) before waking, so the woken thread observes it
41 _semaphore.release();
42 }
43
44 void Wait() noexcept
45 {
46 _semaphore.acquire();
47 }
48
49 [[nodiscard]] bool IsSet() const noexcept
50 {
51 return _set.load(std::memory_order_acquire);
52 }
53
54 private:
55 std::atomic<bool> _set { false };
56 std::binary_semaphore _semaphore { 0 };
57 std::function<void()> _waker;
58 };
59
60 template <typename T>
61 class SyncWaitTask;
62
63 /// Final-suspension awaiter for the SyncWait driver: raises the promise's completion signal once
64 /// the driver has fully suspended. Lifted to namespace scope (a local class may not have member
65 /// templates) so a single awaiter serves every @c SyncWaitPromise specialization.
66 struct SyncWaitFinalAwaiter
67 {
68 [[nodiscard]] bool await_ready() const noexcept
69 {
70 return false;
71 }
72 template <typename Promise>
73 void await_suspend(std::coroutine_handle<Promise> coro) const noexcept
74 {
75 coro.promise().Signal();
76 }
77 void await_resume() const noexcept {}
78 };
79
80 /// Shared machinery for the SyncWait driver promise, independent of the result type.
81 ///
82 /// Mirrors @c TaskPromiseBase: it owns the completion signal and the lazy/final-suspension
83 /// behavior, leaving only the result storage to the per-type specialization. The completion signal
84 /// is raised from the @c final_suspend awaiter (i.e. only once the driver has fully suspended), so
85 /// the blocking caller can safely destroy the coroutine frame without racing the resuming thread.
86 class SyncWaitPromiseBase
87 {
88 public:
89 std::suspend_always initial_suspend() noexcept
90 {
91 return {};
92 }
93
94 SyncWaitFinalAwaiter final_suspend() noexcept
95 {
96 return {};
97 }
98
99 void Bind(SyncWaitEvent& signal) noexcept
100 {
101 _event = &signal;
102 }
103
104 void Signal() const noexcept
105 {
106 _event->Set();
107 }
108
109 private:
110 SyncWaitEvent* _event = nullptr;
111 };
112
113 /// Promise for the internal driver coroutine used by SyncWait (non-void result).
114 template <typename T>
115 class SyncWaitPromise final: public SyncWaitPromiseBase
116 {
117 public:
118 SyncWaitTask<T> get_return_object() noexcept;
119
120 void unhandled_exception() noexcept
121 {
122 _result.SetException();
123 }
124
125 void return_value(T&& value) noexcept(std::is_nothrow_move_constructible_v<T>)
126 {
127 _result.SetValue(std::move(value));
128 }
129
130 void return_value(T const& value)
131 {
132 _result.SetValue(value);
133 }
134
135 [[nodiscard]] T Take()
136 {
137 return _result.Take();
138 }
139
140 private:
141 CoroutineResult<T> _result;
142 };
143
144 template <>
145 class SyncWaitPromise<void> final: public SyncWaitPromiseBase
146 {
147 public:
148 SyncWaitTask<void> get_return_object() noexcept;
149
150 void unhandled_exception() noexcept
151 {
152 _result.SetException();
153 }
154
155 void return_void() noexcept {}
156
157 void Take() const
158 {
159 _result.Take();
160 }
161
162 private:
163 CoroutineResult<void> _result;
164 };
165
166 /// Internal RAII owner of the SyncWait driver coroutine.
167 template <typename T>
168 class SyncWaitTask
169 {
170 public:
171 using promise_type = SyncWaitPromise<T>;
172 using Handle = std::coroutine_handle<promise_type>;
173
174 explicit SyncWaitTask(Handle handle) noexcept:
175 _handle { handle }
176 {
177 }
178
179 SyncWaitTask(SyncWaitTask&& other) noexcept:
180 _handle { std::exchange(other._handle, {}) }
181 {
182 }
183
184 SyncWaitTask(SyncWaitTask const&) = delete;
185 SyncWaitTask& operator=(SyncWaitTask const&) = delete;
186 SyncWaitTask& operator=(SyncWaitTask&&) = delete;
187
188 ~SyncWaitTask()
189 {
190 if (_handle)
191 _handle.destroy();
192 }
193
194 void Start(SyncWaitEvent& signal)
195 {
196 _handle.promise().Bind(signal);
197 _handle.resume();
198 }
199
200 [[nodiscard]] promise_type& Promise() noexcept
201 {
202 return _handle.promise();
203 }
204
205 private:
206 Handle _handle;
207 };
208
209 template <typename T>
210 SyncWaitTask<T> SyncWaitPromise<T>::get_return_object() noexcept
211 {
212 return SyncWaitTask<T> { std::coroutine_handle<SyncWaitPromise<T>>::from_promise(*this) };
213 }
214
215 inline SyncWaitTask<void> SyncWaitPromise<void>::get_return_object() noexcept
216 {
217 return SyncWaitTask<void> { std::coroutine_handle<SyncWaitPromise<void>>::from_promise(*this) };
218 }
219
220 template <typename T>
221 SyncWaitTask<T> MakeSyncWaitTask(Task<T> task)
222 {
223 if constexpr (std::is_void_v<T>)
224 co_await std::move(task);
225 else
226 co_return co_await std::move(task);
227 }
228
229} // namespace detail
230
231/// Blocks the calling thread until @p task completes and returns its result (or rethrows).
232///
233/// Use this from non-coroutine code (e.g. @c main, tests) when the task resumes on a
234/// thread-backed scheduler. If the task resumes on a @ref ManualExecutor that only the
235/// calling thread pumps, use @c SyncWaitPumping instead to avoid a deadlock.
236///
237/// @tparam T The task's result type.
238/// @param task The task to drive to completion (consumed).
239/// @return The task's result, or @c void.
240template <typename T>
241T SyncWait(Task<T> task)
242{
243 detail::SyncWaitEvent signal;
244 auto driver = detail::MakeSyncWaitTask<T>(std::move(task));
245 driver.Start(signal);
246 signal.Wait();
247 return driver.Promise().Take();
248}
249
250/// Drives @p task to completion by pumping @p executor on the calling thread.
251///
252/// This is the single-threaded counterpart to @c SyncWait: rather than blocking, it runs
253/// the executor's queued work (including the task's resumption) until the task finishes.
254/// Intended for an app/event-loop thread and for deterministic tests using a
255/// @ref ManualExecutor.
256///
257/// @tparam T The task's result type.
258/// @tparam Executor An executor exposing @c RunUntil(predicate).
259/// @param task The task to drive to completion (consumed).
260/// @param executor The executor to pump (typically a @ref ManualExecutor).
261/// @return The task's result, or @c void.
262template <typename T, typename Executor>
263T SyncWaitPumping(Task<T> task, Executor& executor)
264{
265 detail::SyncWaitEvent signal;
266 // If the task completes on a thread other than this pumping thread, nudge the executor so
267 // RunUntil re-checks the predicate instead of sleeping forever (a no-op item enqueues under
268 // the executor's lock and notifies its condition variable). Install before starting the driver.
269 signal.SetWaker([&executor] { executor.Post([] {}); });
270 auto driver = detail::MakeSyncWaitTask<T>(std::move(task));
271 driver.Start(signal);
272 executor.RunUntil([&signal] { return signal.IsSet(); });
273 // RunUntil returns as soon as the completion flag is observed, but Set() may still be mid-execution
274 // on the resuming thread (it raises the flag and nudges the waker before its trailing
275 // _semaphore.release()). Acquire the semaphore — release() is the last statement of Set() — so the
276 // local `signal` is not destroyed while Set() is still touching it. Mirrors SyncWait's handshake.
277 signal.Wait();
278 return driver.Promise().Take();
279}
280
281} // namespace Lightweight::Async