libs/capy/include/boost/capy/when_all.hpp

96.9% Lines (95/98) 90.7% Functions (284/313) 100.0% Branches (23/23)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_launchable_task.hpp>
16 #include <boost/capy/coro.hpp>
17 #include <boost/capy/ex/executor_ref.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Map one task result type to the tuple fragment it contributes.

    A `void` result contributes nothing (an empty tuple); any other
    type `T` contributes `std::tuple<T>`.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<
    !std::is_void_v<T>,
    std::tuple<T>,
    std::tuple<>>;

/** Type trait to filter void types from a tuple.

    Void-returning tasks do not contribute a value to the result tuple.
    This trait concatenates the per-type fragments, in order, to form
    the filtered result type.

    Example: filter_void_tuple_t<int, void, string> = tuple<int, string>
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(
    std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Holds the result of a single task within when_all.

    The slot starts empty, is filled by set(), and is drained by get().

    @tparam T The (non-void) result type of the task.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    /** Store the task's result.

        The value is constructed in place, so `T` only needs to be
        move-constructible (assignment through the optional would
        additionally require `T` to be move-assignable). Any value
        stored earlier is destroyed first.

        @param v The result to store.
    */
    void set(T v)
    {
        value_.emplace(std::move(v));
    }

    /** Move the stored result out of the holder.

        Precondition: set() has been called; the optional is
        dereferenced without a check.

        @return The stored value, moved out.
    */
    T get() &&
    {
        return std::move(*value_);
    }
};

/** Specialization for void tasks - no value storage needed.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<coro, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 2 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 coro continuation_;
109 executor_ref caller_ex_;
110
111 28 when_all_state()
112
1/1
✓ Branch 5 taken 28 times.
28 : remaining_count_(task_count)
113 {
114 28 }
115
116 // Runners self-destruct in final_suspend. No destruction needed here.
117
118 /** Capture an exception (first one wins).
119 */
120 11 void capture_exception(std::exception_ptr ep)
121 {
122 11 bool expected = false;
123
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
124 expected, true, std::memory_order_relaxed))
125 8 first_exception_ = ep;
126 11 }
127
128 };
129
130 /** Wrapper coroutine that intercepts task completion.
131
132 This runner awaits its assigned task and stores the result in
133 the shared state, or captures the exception and requests stop.
134 */
135 template<typename T, typename... Ts>
136 struct when_all_runner
137 {
138 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
139 {
140 when_all_state<Ts...>* state_ = nullptr;
141 executor_ref ex_;
142 std::stop_token stop_token_;
143
144 68 when_all_runner get_return_object()
145 {
146 68 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
147 }
148
149 68 std::suspend_always initial_suspend() noexcept
150 {
151 68 return {};
152 }
153
154 68 auto final_suspend() noexcept
155 {
156 struct awaiter
157 {
158 promise_type* p_;
159
160 8 bool await_ready() const noexcept
161 {
162 8 return false;
163 }
164
165 8 void await_suspend(coro h) noexcept
166 {
167 // Extract everything needed for signaling before
168 // self-destruction. Inline dispatch may destroy
169 // when_all_state, so we can't access members after.
170 8 auto* state = p_->state_;
171 8 auto* counter = &state->remaining_count_;
172 8 auto caller_ex = state->caller_ex_;
173 8 auto cont = state->continuation_;
174
175 // Self-destruct first - state no longer destroys runners
176 8 h.destroy();
177
178 // Signal completion. If last, dispatch parent.
179 // Uses only local copies - safe even if state
180 // is destroyed during inline dispatch.
181 8 auto remaining = counter->fetch_sub(1, std::memory_order_acq_rel);
182
2/2
✓ Branch 0 taken 4 times.
✓ Branch 1 taken 4 times.
8 if(remaining == 1)
183 4 caller_ex.dispatch(cont);
184 8 }
185
186 void await_resume() const noexcept
187 {
188 }
189 };
190 68 return awaiter{this};
191 }
192
193 57 void return_void()
194 {
195 57 }
196
197 11 void unhandled_exception()
198 {
199 11 state_->capture_exception(std::current_exception());
200 // Request stop for sibling tasks
201 11 state_->stop_source_.request_stop();
202 11 }
203
204 template<class Awaitable>
205 struct transform_awaiter
206 {
207 std::decay_t<Awaitable> a_;
208 promise_type* p_;
209
210 68 bool await_ready()
211 {
212 68 return a_.await_ready();
213 }
214
215 68 decltype(auto) await_resume()
216 {
217 68 return a_.await_resume();
218 }
219
220 template<class Promise>
221 68 auto await_suspend(std::coroutine_handle<Promise> h)
222 {
223
1/1
✓ Branch 3 taken 54 times.
68 return a_.await_suspend(h, p_->ex_, p_->stop_token_);
224 }
225 };
226
227 template<class Awaitable>
228 68 auto await_transform(Awaitable&& a)
229 {
230 using A = std::decay_t<Awaitable>;
231 if constexpr (IoAwaitable<A>)
232 {
233 return transform_awaiter<Awaitable>{
234 136 std::forward<Awaitable>(a), this};
235 }
236 else
237 {
238 static_assert(sizeof(A) == 0, "requires IoAwaitable");
239 }
240 68 }
241 };
242
243 std::coroutine_handle<promise_type> h_;
244
245 68 explicit when_all_runner(std::coroutine_handle<promise_type> h)
246 68 : h_(h)
247 {
248 68 }
249
250 // Enable move for all clang versions - some versions need it
251 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
252
253 // Non-copyable
254 when_all_runner(when_all_runner const&) = delete;
255 when_all_runner& operator=(when_all_runner const&) = delete;
256 when_all_runner& operator=(when_all_runner&&) = delete;
257
258 68 auto release() noexcept
259 {
260 68 return std::exchange(h_, nullptr);
261 }
262 };
263
264 /** Create a runner coroutine for a single task.
265
266 Task is passed directly to ensure proper coroutine frame storage.
267 */
268 template<std::size_t Index, typename T, typename... Ts>
269 when_all_runner<T, Ts...>
270
1/1
✓ Branch 1 taken 68 times.
68 make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
271 {
272 if constexpr (std::is_void_v<T>)
273 {
274 co_await std::move(inner);
275 }
276 else
277 {
278 std::get<Index>(state->results_).set(co_await std::move(inner));
279 }
280 136 }
281
282 /** Internal awaitable that launches all runner coroutines and waits.
283
284 This awaitable is used inside the when_all coroutine to handle
285 the concurrent execution of child tasks.
286 */
287 template<typename... Ts>
288 class when_all_launcher
289 {
290 std::tuple<task<Ts>...>* tasks_;
291 when_all_state<Ts...>* state_;
292
293 public:
294 28 when_all_launcher(
295 std::tuple<task<Ts>...>* tasks,
296 when_all_state<Ts...>* state)
297 28 : tasks_(tasks)
298 28 , state_(state)
299 {
300 28 }
301
302 28 bool await_ready() const noexcept
303 {
304 28 return sizeof...(Ts) == 0;
305 }
306
307 28 coro await_suspend(coro continuation, executor_ref caller_ex, std::stop_token parent_token = {})
308 {
309 28 state_->continuation_ = continuation;
310 28 state_->caller_ex_ = caller_ex;
311
312 // Forward parent's stop requests to children
313
2/2
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 22 times.
28 if(parent_token.stop_possible())
314 {
315 12 state_->parent_stop_callback_.emplace(
316 parent_token,
317 6 typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
318
319
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
6 if(parent_token.stop_requested())
320 2 state_->stop_source_.request_stop();
321 }
322
323 // CRITICAL: If the last task finishes synchronously then the parent
324 // coroutine resumes, destroying its frame, and destroying this object
325 // prior to the completion of await_suspend. Therefore, await_suspend
326 // must ensure `this` cannot be referenced after calling `launch_one`
327 // for the last time.
328 28 auto token = state_->stop_source_.get_token();
329 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
330
2/2
✓ Branch 2 taken 4 times.
✓ Branch 6 taken 4 times.
4 (..., launch_one<Is>(caller_ex, token));
331
2/2
✓ Branch 1 taken 24 times.
✓ Branch 1 taken 4 times.
28 }(std::index_sequence_for<Ts...>{});
332
333 // Let signal_completion() handle resumption
334 56 return std::noop_coroutine();
335 28 }
336
337 28 void await_resume() const noexcept
338 {
339 // Results are extracted by the when_all coroutine from state
340 28 }
341
342 private:
343 template<std::size_t I>
344 68 void launch_one(executor_ref caller_ex, std::stop_token token)
345 {
346
1/1
✓ Branch 2 taken 68 times.
68 auto runner = make_when_all_runner<I>(
347 68 std::move(std::get<I>(*tasks_)), state_);
348
349 68 auto h = runner.release();
350 68 h.promise().state_ = state_;
351 68 h.promise().ex_ = caller_ex;
352 68 h.promise().stop_token_ = token;
353
354 68 coro ch{h};
355 68 state_->runner_handles_[I] = ch;
356
1/1
✓ Branch 1 taken 68 times.
68 state_->caller_ex_.dispatch(ch);
357 68 }
358 };
359
360 /** Compute the result type for when_all.
361
362 Returns void when all tasks are void (P2300 aligned),
363 otherwise returns a tuple with void types filtered out.
364 */
365 template<typename... Ts>
366 using when_all_result_t = std::conditional_t<
367 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
368 void,
369 filter_void_tuple_t<Ts...>>;
370
371 /** Helper to extract a single result, returning empty tuple for void.
372 This is a separate function to work around a GCC-11 ICE that occurs
373 when using nested immediately-invoked lambdas with pack expansion.
374 */
375 template<std::size_t I, typename... Ts>
376 47 auto extract_single_result(when_all_state<Ts...>& state)
377 {
378 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
379 if constexpr (std::is_void_v<T>)
380 2 return std::tuple<>();
381 else
382
1/1
✓ Branch 4 taken 45 times.
45 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
383 }
384
385 /** Extract results from state, filtering void types.
386 */
387 template<typename... Ts>
388 19 auto extract_results(when_all_state<Ts...>& state)
389 {
390 19 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
391
3/3
✓ Branch 1 taken 4 times.
✓ Branch 4 taken 4 times.
✓ Branch 7 taken 4 times.
4 return std::tuple_cat(extract_single_result<Is>(state)...);
392
1/1
✓ Branch 1 taken 19 times.
38 }(std::index_sequence_for<Ts...>{});
393 }
394
395 } // namespace detail
396
397 /** Execute multiple tasks concurrently and collect their results.
398
399 Launches all tasks simultaneously and waits for all to complete
400 before returning. Results are collected in input order. If any
401 task throws, cancellation is requested for siblings and the first
402 exception is rethrown after all tasks complete.
403
404 @li All child tasks run concurrently on the caller's executor
405 @li Results are returned as a tuple in input order
406 @li Void-returning tasks do not contribute to the result tuple
407 @li If all tasks return void, `when_all` returns `task<void>`
408 @li First exception wins; subsequent exceptions are discarded
409 @li Stop is requested for siblings on first error
410 @li Completes only after all children have finished
411
412 @par Thread Safety
413 The returned task must be awaited from a single execution context.
414 Child tasks execute concurrently but complete through the caller's
415 executor.
416
417 @param tasks The tasks to execute concurrently. Each task is
418 consumed (moved-from) when `when_all` is awaited.
419
420 @return A task yielding a tuple of non-void results. Returns
421 `task<void>` when all input tasks return void.
422
423 @par Example
424
425 @code
426 task<> example()
427 {
428 // Concurrent fetch, results collected in order
429 auto [user, posts] = co_await when_all(
430 fetch_user( id ), // task<User>
431 fetch_posts( id ) // task<std::vector<Post>>
432 );
433
434 // Void tasks don't contribute to result
435 co_await when_all(
436 log_event( "start" ), // task<void>
437 notify_user( id ) // task<void>
438 );
439 // Returns task<void>, no result tuple
440 }
441 @endcode
442
443 @see task
444 */
445 template<typename... Ts>
446 [[nodiscard]] task<detail::when_all_result_t<Ts...>>
447
1/1
✓ Branch 1 taken 28 times.
28 when_all(task<Ts>... tasks)
448 {
449 using result_type = detail::when_all_result_t<Ts...>;
450
451 // State is stored in the coroutine frame, using the frame allocator
452 detail::when_all_state<Ts...> state;
453
454 // Store tasks in the frame
455 std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
456
457 // Launch all tasks and wait for completion
458 co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
459
460 // Propagate first exception if any.
461 // Safe without explicit acquire: capture_exception() is sequenced-before
462 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
463 // last task's decrement that resumes this coroutine.
464 if(state.first_exception_)
465 std::rethrow_exception(state.first_exception_);
466
467 // Extract and return results
468 if constexpr (std::is_void_v<result_type>)
469 co_return;
470 else
471 co_return detail::extract_results(state);
472 56 }
473
474 /// Compute the result type of `when_all` for the given task types.
475 template<typename... Ts>
476 using when_all_result_type = detail::when_all_result_t<Ts...>;
477
478 } // namespace capy
479 } // namespace boost
480
481 #endif
482