diff --git a/include/boost/capy.hpp b/include/boost/capy.hpp index 8002136..e066756 100644 --- a/include/boost/capy.hpp +++ b/include/boost/capy.hpp @@ -63,6 +63,7 @@ #include #include #include +#include #include #endif diff --git a/include/boost/capy/frame_allocator.hpp b/include/boost/capy/frame_allocator.hpp index 9466c71..10d36ed 100644 --- a/include/boost/capy/frame_allocator.hpp +++ b/include/boost/capy/frame_allocator.hpp @@ -16,6 +16,7 @@ #include #include #include +#include namespace boost { namespace capy { diff --git a/include/boost/capy/when_all.hpp b/include/boost/capy/when_all.hpp new file mode 100644 index 0000000..527d36e --- /dev/null +++ b/include/boost/capy/when_all.hpp @@ -0,0 +1,521 @@ +// +// Copyright (c) 2026 Steve Gerbino +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// Official repository: https://github.com/cppalliance/capy +// + +#ifndef BOOST_CAPY_WHEN_ALL_HPP +#define BOOST_CAPY_WHEN_ALL_HPP + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace boost { +namespace capy { + +namespace detail { + +/** Type trait to filter void types from a tuple. + + Void-returning tasks do not contribute a value to the result tuple. + This trait computes the filtered result type. + + Example: filter_void_tuple_t = tuple +*/ +template +struct filter_void_tuple; + +template<> +struct filter_void_tuple<> +{ + using type = std::tuple<>; +}; + +template +struct filter_void_tuple +{ +private: + using rest_type = typename filter_void_tuple::type; + +public: + using type = std::conditional_t< + std::is_void_v, + rest_type, + decltype(std::tuple_cat( + std::declval>(), + std::declval()))>; +}; + +template +using filter_void_tuple_t = typename filter_void_tuple::type; + +/** Holds the result of a single task within when_all. +*/ +template +struct result_holder +{ + std::optional value_; + + void set(T v) + { + value_ = std::move(v); + } + + T get() && + { + return std::move(*value_); + } +}; + +/** Specialization for void tasks - no value storage needed. +*/ +template<> +struct result_holder +{ + void set() + { + } +}; + +/** Shared state for when_all operation. + + @tparam Ts The result types of the tasks. +*/ +template +struct when_all_state +{ + static constexpr std::size_t task_count = sizeof...(Ts); + + // Completion tracking - when_all waits for all children + std::atomic remaining_count_; + + // Result storage in input order + std::tuple...> results_; + + // Runner handles - destroyed in await_resume while allocator is valid + std::array runner_handles_{}; + + // Exception storage - first error wins, others discarded + std::atomic has_exception_{false}; + std::exception_ptr first_exception_; + + // Stop propagation - on error, request stop for siblings + std::stop_source stop_source_; + + // Connects parent's stop_token to our stop_source + struct stop_callback_fn + { + std::stop_source* source_; + void operator()() const { source_->request_stop(); } + }; + using stop_callback_t = std::stop_callback; + std::optional parent_stop_callback_; + + // Parent resumption + coro continuation_; + any_dispatcher caller_dispatcher_; + + explicit when_all_state(std::size_t count) + : remaining_count_(count) + { + } + + ~when_all_state() + { + destroy_runners(); + } + + void store_runner(std::size_t index, coro h) + { + runner_handles_[index] = h; + } + + void destroy_runners() + { + for(auto& h : runner_handles_) + { + if(h) + { + h.destroy(); + h = nullptr; + } + } + } + + /** Capture an exception (first one wins). + */ + bool capture_exception(std::exception_ptr ep) + { + bool expected = false; + if(has_exception_.compare_exchange_strong( + expected, true, std::memory_order_relaxed)) + { + first_exception_ = ep; + return true; + } + return false; + } + + /** Signal that a task has completed. + + The last child to complete triggers resumption of the parent. + */ + coro signal_completion() + { + auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel); + if(remaining == 1) + return caller_dispatcher_(continuation_); + return std::noop_coroutine(); + } + +}; + +/** Wrapper coroutine that intercepts task completion. + + This runner awaits its assigned task and stores the result in + the shared state, or captures the exception and requests stop. +*/ +template +struct when_all_runner +{ + struct promise_type : frame_allocating_base + { + when_all_state* state_ = nullptr; + any_dispatcher ex_; + std::stop_token stop_token_; + + when_all_runner get_return_object() + { + return when_all_runner(std::coroutine_handle::from_promise(*this)); + } + + std::suspend_always initial_suspend() noexcept + { + return {}; + } + + auto final_suspend() noexcept + { + struct awaiter + { + promise_type* p_; + + bool await_ready() const noexcept + { + return false; + } + + coro await_suspend(coro) noexcept + { + // Signal completion; last task resumes parent + return p_->state_->signal_completion(); + } + + void await_resume() const noexcept + { + } + }; + return awaiter{this}; + } + + void return_void() + { + } + + void unhandled_exception() + { + state_->capture_exception(std::current_exception()); + // Request stop for sibling tasks + state_->stop_source_.request_stop(); + } + + template + struct transform_awaiter + { + std::decay_t a_; + promise_type* p_; + + bool await_ready() + { + return a_.await_ready(); + } + + auto await_resume() + { + return a_.await_resume(); + } + + template + auto await_suspend(std::coroutine_handle h) + { + using A = std::decay_t; + // Propagate stop_token to nested awaitables + if constexpr (stoppable_awaitable) + return a_.await_suspend(h, p_->ex_, p_->stop_token_); + else + return a_.await_suspend(h, p_->ex_); + } + }; + + template + auto await_transform(Awaitable&& a) + { + using A = std::decay_t; + if constexpr (affine_awaitable) + { + return transform_awaiter{ + std::forward(a), this}; + } + else + { + return make_affine(std::forward(a), ex_); + } + } + }; + + std::coroutine_handle h_; + + explicit when_all_runner(std::coroutine_handle h) + : h_(h) + { + } + + ~when_all_runner() + { + if(h_) + h_.destroy(); + } + + when_all_runner(when_all_runner const&) = delete; + when_all_runner& operator=(when_all_runner const&) = delete; + + when_all_runner(when_all_runner&& other) noexcept + : h_(std::exchange(other.h_, nullptr)) + { + } + + when_all_runner& operator=(when_all_runner&& other) noexcept + { + if(this != &other) + { + if(h_) + h_.destroy(); + h_ = std::exchange(other.h_, nullptr); + } + return *this; + } + + auto release() noexcept + { + return std::exchange(h_, nullptr); + } +}; + +/** Create a runner coroutine for a single task. +*/ +template +when_all_runner +make_when_all_runner(task inner, when_all_state* state) +{ + if constexpr (std::is_void_v) + { + co_await std::move(inner); + std::get(state->results_).set(); + } + else + { + auto result = co_await std::move(inner); + std::get(state->results_).set(std::move(result)); + } +} + +} // namespace detail + +/** Awaitable that concurrently executes multiple tasks. + + @tparam Ts The return types of the tasks. + + Key features: + @li All child tasks are launched concurrently + @li Results are collected in input order + @li First error is captured; subsequent errors are discarded + @li On error, stop is requested for all siblings + @li Completes only after all children have completed + @li Void tasks do not contribute to the result tuple +*/ +template +class when_all_awaitable +{ + using state_type = detail::when_all_state; + using filtered_tuple = detail::filter_void_tuple_t; + +public: + /** Result type with void tasks filtered out. + Returns void when all tasks are void (P2300 aligned). + */ + using result_type = std::conditional_t< + std::is_same_v>, + void, + filtered_tuple>; + +private: + std::tuple...> tasks_; + std::unique_ptr state_; + +public: + explicit when_all_awaitable(task... tasks) + : tasks_(std::move(tasks)...) + { + } + + when_all_awaitable(when_all_awaitable const&) = delete; + when_all_awaitable& operator=(when_all_awaitable const&) = delete; + when_all_awaitable(when_all_awaitable&&) = default; + when_all_awaitable& operator=(when_all_awaitable&&) = default; + + bool await_ready() const noexcept + { + return sizeof...(Ts) == 0; + } + + /** Affine awaitable protocol. + */ + template + coro await_suspend(coro continuation, D const& caller_ex) + { + return await_suspend_impl(continuation, caller_ex, std::stop_token{}); + } + + /** Stoppable awaitable protocol. + */ + template + coro await_suspend(coro continuation, D const& caller_ex, std::stop_token token) + { + return await_suspend_impl(continuation, caller_ex, token); + } + + /** Extract results or propagate the first captured error. + */ + result_type await_resume() + { + if(state_->first_exception_) + std::rethrow_exception(state_->first_exception_); + + if constexpr (std::is_void_v) + return; + else + return extract_results_impl<0>(); + } + +private: + template + coro await_suspend_impl(coro continuation, D const& caller_ex, std::stop_token parent_token) + { + state_ = std::make_unique(sizeof...(Ts)); + state_->continuation_ = continuation; + state_->caller_dispatcher_ = caller_ex; + + // Forward parent's stop requests to children + if(parent_token.stop_possible()) + { + state_->parent_stop_callback_.emplace( + parent_token, + typename state_type::stop_callback_fn{&state_->stop_source_}); + + if(parent_token.stop_requested()) + state_->stop_source_.request_stop(); + } + + // Launch all tasks concurrently + launch_all(caller_ex, std::index_sequence_for{}); + + // Let signal_completion() handle resumption to avoid double-resume + return std::noop_coroutine(); + } + + template + void launch_all(D const& caller_ex, std::index_sequence) + { + (..., launch_one(caller_ex)); + } + + template + void launch_one(D const& caller_ex) + { + using T = std::tuple_element_t>; + + auto runner = detail::make_when_all_runner( + std::move(std::get(tasks_)), state_.get()); + + auto h = runner.release(); + h.promise().state_ = state_.get(); + h.promise().ex_ = caller_ex; + + // Give child a stop_token connected to our stop_source + h.promise().stop_token_ = state_->stop_source_.get_token(); + + state_->store_runner(I, coro{h}); + + // Start the task via dispatcher + caller_ex(coro{h}).resume(); + } + + // Recursively build result tuple, skipping void types + template + auto extract_results_impl() + { + if constexpr (I >= sizeof...(Ts)) + return std::tuple<>(); + else + { + using T = std::tuple_element_t>; + if constexpr (std::is_void_v) + return extract_results_impl(); + else + return std::tuple_cat( + std::make_tuple(std::move(std::get(state_->results_)).get()), + extract_results_impl()); + } + } +}; + +/** Wait for all tasks to complete concurrently. + + @par Example + @code + task example() { + auto [a, b] = co_await when_all( + fetch_int(), // task + fetch_string() // task + ); + } + @endcode + + @param tasks The tasks to execute concurrently. + @return An awaitable yielding a tuple of results. +*/ +template +[[nodiscard]] auto when_all(task... tasks) +{ + return when_all_awaitable(std::move(tasks)...); +} + +} // namespace capy +} // namespace boost + +#endif diff --git a/test/unit/when_all.cpp b/test/unit/when_all.cpp new file mode 100644 index 0000000..c3253fe --- /dev/null +++ b/test/unit/when_all.cpp @@ -0,0 +1,1136 @@ +// +// Copyright (c) 2025 Vinnie Falco (vinnie dot falco at gmail dot com) +// +// Distributed under the Boost Software License, Version 1.0. (See accompanying +// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt) +// +// Official repository: https://github.com/cppalliance/capy +// + +// Test that header file is self-contained. +#include + +#include +#include +#include + +#include "test_suite.hpp" + +#include +#include +#include +#include + +namespace boost { +namespace capy { + +// Static assertions for void filtering type trait +static_assert(std::is_same_v< + detail::filter_void_tuple_t, + std::tuple>); +static_assert(std::is_same_v< + detail::filter_void_tuple_t, + std::tuple<>>); +static_assert(std::is_same_v< + detail::filter_void_tuple_t, + std::tuple>); +static_assert(std::is_same_v< + detail::filter_void_tuple_t, + std::tuple<>>); + +// Verify result_type: void when all tasks are void, tuple otherwise +static_assert(std::is_same_v< + when_all_awaitable::result_type, + std::tuple>); +static_assert(std::is_same_v< + when_all_awaitable::result_type, + std::tuple>); +static_assert(std::is_void_v< + when_all_awaitable::result_type>); +static_assert(std::is_void_v< + when_all_awaitable::result_type>); +static_assert(std::is_void_v< + when_all_awaitable::result_type>); + +// Verify when_all_awaitable satisfies awaitable protocols +static_assert(affine_awaitable, any_dispatcher>); +static_assert(stoppable_awaitable, any_dispatcher>); + +/** Simple synchronous dispatcher for testing. +*/ +struct test_dispatcher +{ + int* dispatch_count_; + + explicit test_dispatcher(int& count) + : dispatch_count_(&count) + { + } + + coro operator()(coro h) const + { + ++(*dispatch_count_); + return h; + } +}; + +static_assert(dispatcher); + +struct test_exception : std::runtime_error +{ + explicit test_exception(const char* msg) + : std::runtime_error(msg) + { + } +}; + +[[noreturn]] inline void +throw_test_exception(char const* msg) +{ + throw test_exception(msg); +} + +struct when_all_test +{ + // Helper tasks + static task + returns_int(int value) + { + co_return value; + } + + static task + returns_string(std::string value) + { + co_return value; + } + + static task + void_task() + { + co_return; + } + + static task + throws_exception(char const* msg) + { + throw_test_exception(msg); + co_return 0; + } + + static task + void_throws_exception(char const* msg) + { + throw_test_exception(msg); + co_return; + } + + // Test: All tasks succeed + void + testAllSucceed() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto [a, b] = co_await when_all( + returns_int(10), + returns_int(20) + ); + co_return a + b; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, 30); + } + + // Test: Three tasks succeed + void + testThreeTasksSucceed() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto [a, b, c] = co_await when_all( + returns_int(1), + returns_int(2), + returns_int(3) + ); + co_return a + b + c; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, 6); + } + + // Test: Mixed types (int, string, void) + void + testMixedTypes() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + std::string result; + + async_run(d)( + []() -> task { + // void_task() doesn't contribute to result tuple + auto [a, b] = co_await when_all( + returns_int(42), + returns_string("hello"), + void_task() + ); + co_return b + std::to_string(a); + }(), + [&](std::string r) { completed = true; result = std::move(r); }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, "hello42"); + } + + // Test: Single task in when_all + void + testSingleTask() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto [a] = co_await when_all( + returns_int(99) + ); + co_return a; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, 99); + } + + // Test: First exception captured + void + testFirstException() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + bool caught_exception = false; + std::string error_msg; + + async_run(d)( + []() -> task { + auto [a, b] = co_await when_all( + throws_exception("first error"), + returns_int(10) + ); + co_return a + b; + }(), + [&](int) { completed = true; }, + [&](std::exception_ptr ep) { + try { + std::rethrow_exception(ep); + } catch (test_exception const& e) { + caught_exception = true; + error_msg = e.what(); + } + }); + + BOOST_TEST(!completed); + BOOST_TEST(caught_exception); + BOOST_TEST_EQ(error_msg, "first error"); + } + + // Test: Multiple failures - first exception wins + void + testMultipleFailuresFirstWins() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool caught_exception = false; + std::string error_msg; + + async_run(d)( + []() -> task { + auto [a, b, c] = co_await when_all( + throws_exception("error_1"), + throws_exception("error_2"), + throws_exception("error_3") + ); + co_return a + b + c; + }(), + [](int) {}, + [&](std::exception_ptr ep) { + try { + std::rethrow_exception(ep); + } catch (test_exception const& e) { + caught_exception = true; + error_msg = e.what(); + } + }); + + BOOST_TEST(caught_exception); + BOOST_TEST( + error_msg == "error_1" || + error_msg == "error_2" || + error_msg == "error_3"); + } + + // Test: Void task throws exception + void + testVoidTaskException() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool caught_exception = false; + std::string error_msg; + + async_run(d)( + []() -> task { + auto [a] = co_await when_all( + returns_int(10), + void_throws_exception("void error") + ); + co_return a; + }(), + [](int) {}, + [&](std::exception_ptr ep) { + try { + std::rethrow_exception(ep); + } catch (test_exception const& e) { + caught_exception = true; + error_msg = e.what(); + } + }); + + BOOST_TEST(caught_exception); + BOOST_TEST_EQ(error_msg, "void error"); + } + + // Test: Nested when_all calls + void + testNestedWhenAll() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto inner1 = []() -> task { + auto [a, b] = co_await when_all( + returns_int(1), + returns_int(2) + ); + co_return a + b; + }; + + auto inner2 = []() -> task { + auto [a, b] = co_await when_all( + returns_int(3), + returns_int(4) + ); + co_return a + b; + }; + + auto [x, y] = co_await when_all( + inner1(), + inner2() + ); + + co_return x + y; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, 10); // (1+2) + (3+4) = 10 + } + + // Test: All void tasks return void (not empty tuple) + void + testAllVoidTasks() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + + async_run(d)( + []() -> task { + // All void tasks return void, not std::tuple<> + co_await when_all( + void_task(), + void_task(), + void_task() + ); + co_return; + }(), + [&]() { completed = true; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + } + + // Test: Result type correctness - void types filtered, all-void returns void + void + testResultType() + { + // Mixed types: void filtered out + using mixed_result = when_all_awaitable::result_type; + static_assert(std::is_same_v< + mixed_result, + std::tuple>); + + // All void: returns void (not empty tuple) + using all_void_result = when_all_awaitable::result_type; + static_assert(std::is_void_v); + + // Single void: returns void + using single_void_result = when_all_awaitable::result_type; + static_assert(std::is_void_v); + } + + //---------------------------------------------------------- + // Frame allocator verification tests + //---------------------------------------------------------- + + /** Counting frame allocator to verify all coroutines use the allocator. + */ + struct counting_frame_allocator + { + std::size_t* alloc_count_; + std::size_t* dealloc_count_; + + void* allocate(std::size_t n) + { + ++(*alloc_count_); + return ::operator new(n); + } + + void deallocate(void* p, std::size_t) + { + ++(*dealloc_count_); + ::operator delete(p); + } + }; + + static_assert(frame_allocator); + + // Test: Frame allocator used for two tasks + // Expected: 1 async_run_task + 1 outer task + 2 runners + 2 child tasks = 6 + void + testFrameAllocatorTwoTasks() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool completed = false; + + async_run(d, alloc)( + []() -> task { + auto [a, b] = co_await when_all( + returns_int(10), + returns_int(20) + ); + co_return a + b; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 30); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // 1 async_run_task + 1 outer task + 2 when_all_runners + 2 child tasks = 6 + BOOST_TEST_EQ(alloc_count, 6u); + BOOST_TEST_EQ(dealloc_count, 6u); + } + + // Test: Frame allocator used for three tasks + // Expected: 1 async_run_task + 1 outer task + 3 runners + 3 child tasks = 8 + void + testFrameAllocatorThreeTasks() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool completed = false; + + async_run(d, alloc)( + []() -> task { + auto [a, b, c] = co_await when_all( + returns_int(1), + returns_int(2), + returns_int(3) + ); + co_return a + b + c; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 6); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // 1 async_run_task + 1 outer task + 3 when_all_runners + 3 child tasks = 8 + BOOST_TEST_EQ(alloc_count, 8u); + BOOST_TEST_EQ(dealloc_count, 8u); + } + + // Test: Frame allocator with void tasks + // Expected: 1 async_run_task + 1 outer task + 3 runners + 3 child tasks = 8 + void + testFrameAllocatorWithVoidTasks() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool completed = false; + + async_run(d, alloc)( + []() -> task { + auto [a] = co_await when_all( + returns_int(42), + void_task(), + void_task() + ); + co_return a; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 42); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // 1 async_run_task + 1 outer task + 3 when_all_runners + 3 child tasks = 8 + BOOST_TEST_EQ(alloc_count, 8u); + BOOST_TEST_EQ(dealloc_count, 8u); + } + + // Test: Frame allocator with single task (edge case) + // Expected: 1 async_run_task + 1 outer task + 1 runner + 1 child task = 4 + void + testFrameAllocatorSingleTask() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool completed = false; + + async_run(d, alloc)( + []() -> task { + auto [a] = co_await when_all( + returns_int(99) + ); + co_return a; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 99); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // 1 async_run_task + 1 outer task + 1 when_all_runner + 1 child task = 4 + BOOST_TEST_EQ(alloc_count, 4u); + BOOST_TEST_EQ(dealloc_count, 4u); + } + + // Test: Frame allocator with nested when_all + // Expected: 1 async_run + 1 outer task + 2 outer runners + 2 inner tasks + + // 2*(2 inner runners + 2 inner child tasks) = 2 + 2 + 2 + 8 = 14 + void + testFrameAllocatorNestedWhenAll() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool completed = false; + + async_run(d, alloc)( + []() -> task { + auto inner1 = []() -> task { + auto [a, b] = co_await when_all( + returns_int(1), + returns_int(2) + ); + co_return a + b; + }; + + auto inner2 = []() -> task { + auto [a, b] = co_await when_all( + returns_int(3), + returns_int(4) + ); + co_return a + b; + }; + + auto [x, y] = co_await when_all( + inner1(), + inner2() + ); + + co_return x + y; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 10); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // Structure: + // - 1 async_run_task + // - 1 outer task (the main lambda) + // - 2 when_all_runners (for inner1(), inner2()) + // - 2 child tasks (inner1, inner2) + // - For each inner when_all: + // - 2 when_all_runners + // - 2 child tasks (returns_int) + // Total: 1 + 1 + 2 + 2 + 2*(2 + 2) = 14 + BOOST_TEST_EQ(alloc_count, 14u); + BOOST_TEST_EQ(dealloc_count, 14u); + } + + // Test: Frame allocator deallocations match allocations on exception + void + testFrameAllocatorWithException() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::size_t alloc_count = 0; + std::size_t dealloc_count = 0; + counting_frame_allocator alloc{&alloc_count, &dealloc_count}; + bool caught_exception = false; + + async_run(d, alloc)( + []() -> task { + auto [a, b] = co_await when_all( + throws_exception("test error"), + returns_int(10) + ); + co_return a + b; + }(), + [](int) {}, + [&](std::exception_ptr) { + caught_exception = true; + }); + + BOOST_TEST(caught_exception); + // Even with exception, all frames should be deallocated + // 1 async_run_task + 1 outer task + 2 when_all_runners + 2 child tasks = 6 + BOOST_TEST_EQ(alloc_count, 6u); + BOOST_TEST_EQ(dealloc_count, 6u); + } + + //---------------------------------------------------------- + // Stop token propagation tests + //---------------------------------------------------------- + + // Helper: task that records if stop was requested + static task + checks_stop_token(std::atomic& stop_was_requested) + { + // This task just returns immediately, but in real usage + // you would check stop_token in a loop + co_return 42; + } + + // Helper: stoppable task that honors stop requests + static task + stoppable_task(std::atomic& counter) + { + ++counter; + co_return counter.load(); + } + + // Test: Stop is requested when a sibling fails + void + testStopRequestedOnError() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool caught_exception = false; + + async_run(d)( + []() -> task { + auto [a, b] = co_await when_all( + throws_exception("error"), + returns_int(10) + ); + co_return a + b; + }(), + [](int) {}, + [&](std::exception_ptr) { + caught_exception = true; + }); + + // Exception should propagate - stop was requested internally + BOOST_TEST(caught_exception); + } + + // Test: All tasks complete even after stop is requested + void + testAllTasksCompleteAfterStop() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + std::atomic completion_count{0}; + bool caught_exception = false; + + auto counting_task = [&]() -> task { + ++completion_count; + co_return 1; + }; + + auto failing_task = [&]() -> task { + ++completion_count; + throw_test_exception("fail"); + co_return 0; + }; + + async_run(d)( + [&]() -> task { + auto [a, b, c] = co_await when_all( + counting_task(), + failing_task(), + counting_task() + ); + co_return a + b + c; + }(), + [](int) {}, + [&](std::exception_ptr) { + caught_exception = true; + }); + + BOOST_TEST(caught_exception); + // All three tasks should have run to completion + BOOST_TEST_EQ(completion_count.load(), 3); + } + + //---------------------------------------------------------- + // Edge case tests + //---------------------------------------------------------- + + // Test: Large number of tasks + void + testManyTasks() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto [a, b, c, d, e, f, g, h] = co_await when_all( + returns_int(1), + returns_int(2), + returns_int(3), + returns_int(4), + returns_int(5), + returns_int(6), + returns_int(7), + returns_int(8) + ); + co_return a + b + c + d + e + f + g + h; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + BOOST_TEST_EQ(result, 36); // 1+2+3+4+5+6+7+8 = 36 + } + + // Test: Task that does multiple internal operations + static task + multi_step_task(int start) + { + int value = start; + // Simulate multiple steps by nesting tasks + value += co_await returns_int(1); + value += co_await returns_int(2); + co_return value; + } + + void + testTasksWithMultipleSteps() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + int result = 0; + + async_run(d)( + []() -> task { + auto [a, b] = co_await when_all( + multi_step_task(10), + multi_step_task(20) + ); + co_return a + b; + }(), + [&](int r) { completed = true; result = r; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // (10+1+2) + (20+1+2) = 13 + 23 = 36 + BOOST_TEST_EQ(result, 36); + } + + // Test: Different exception types - first wins + struct other_exception : std::runtime_error + { + explicit other_exception(const char* msg) + : std::runtime_error(msg) + { + } + }; + + static task + throws_other_exception(char const* msg) + { + throw other_exception(msg); + co_return 0; + } + + void + testDifferentExceptionTypes() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool caught_test = false; + bool caught_other = false; + + async_run(d)( + []() -> task { + auto [a, b] = co_await when_all( + throws_exception("test"), + throws_other_exception("other") + ); + co_return a + b; + }(), + [](int) {}, + [&](std::exception_ptr ep) { + try { + std::rethrow_exception(ep); + } catch (test_exception const&) { + caught_test = true; + } catch (other_exception const&) { + caught_other = true; + } + }); + + // One of them should be caught (first to fail wins) + BOOST_TEST(caught_test || caught_other); + // But not both + BOOST_TEST(!(caught_test && caught_other)); + } + + //---------------------------------------------------------- + // Dispatcher propagation tests + //---------------------------------------------------------- + + // Dispatcher that tracks which tasks were dispatched + struct tracking_dispatcher + { + std::atomic* dispatch_count_; + + explicit tracking_dispatcher(std::atomic& count) + : dispatch_count_(&count) + { + } + + coro operator()(coro h) const + { + ++(*dispatch_count_); + return h; + } + }; + + static_assert(dispatcher); + + void + testDispatcherUsedForAllTasks() + { + std::atomic dispatch_count{0}; + tracking_dispatcher d(dispatch_count); + bool completed = false; + + async_run(d)( + []() -> task { + auto [a, b, c] = co_await when_all( + returns_int(1), + returns_int(2), + returns_int(3) + ); + co_return a + b + c; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 6); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + // Dispatcher should be called for: + // - async_run initial dispatch + // - when_all runners (3) + // - signal_completion resumption + BOOST_TEST(dispatch_count.load() > 0); + } + + //---------------------------------------------------------- + // Result ordering tests + //---------------------------------------------------------- + + // Test: Results are in input order regardless of completion order + void + testResultsInInputOrder() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + + async_run(d)( + []() -> task { + auto [first, second, third] = co_await when_all( + returns_string("first"), + returns_string("second"), + returns_string("third") + ); + BOOST_TEST_EQ(first, "first"); + BOOST_TEST_EQ(second, "second"); + BOOST_TEST_EQ(third, "third"); + co_return; + }(), + [&]() { completed = true; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + } + + // Test: Mixed void and value results maintain order + void + testMixedVoidValueOrder() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + + async_run(d)( + []() -> task { + // void at index 1, values at 0 and 2 + auto [a, b] = co_await when_all( + returns_int(100), + void_task(), + returns_int(300) + ); + // a should be from index 0, b from index 2 + BOOST_TEST_EQ(a, 100); + BOOST_TEST_EQ(b, 300); + co_return; + }(), + [&]() { completed = true; }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + } + + //---------------------------------------------------------- + // Awaitable lifecycle tests + //---------------------------------------------------------- + + // Test: when_all_awaitable is move constructible + void + testAwaitableMoveConstruction() + { + auto awaitable1 = when_all(returns_int(1), returns_int(2)); + auto awaitable2 = std::move(awaitable1); + + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + + async_run(d)( + [aw = std::move(awaitable2)]() mutable -> task { + auto [a, b] = co_await std::move(aw); + co_return a + b; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 3); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + } + + // Test: when_all can be stored and awaited later + void + testDeferredAwait() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool completed = false; + + auto deferred = when_all(returns_int(10), returns_int(20)); + + async_run(d)( + [aw = std::move(deferred)]() mutable -> task { + // Await later + auto [a, b] = co_await std::move(aw); + co_return a + b; + }(), + [&](int r) { + completed = true; + BOOST_TEST_EQ(r, 30); + }, + [](std::exception_ptr) {}); + + BOOST_TEST(completed); + } + + //---------------------------------------------------------- + // Stoppable awaitable protocol tests + //---------------------------------------------------------- + + // Test: when_all satisfies stoppable_awaitable concept + void + testStoppableAwaitableConcept() + { + static_assert(stoppable_awaitable< + when_all_awaitable, + any_dispatcher>); + + static_assert(stoppable_awaitable< + when_all_awaitable, + any_dispatcher>); + } + + // Test: Nested when_all propagates stop + void + testNestedWhenAllStopPropagation() + { + int dispatch_count = 0; + test_dispatcher d(dispatch_count); + bool caught_exception = false; + + async_run(d)( + []() -> task { + auto inner_failing = []() -> task { + auto [a, b] = co_await when_all( + throws_exception("inner error"), + returns_int(1) + ); + co_return a + b; + }; + + auto inner_success = []() -> task { + auto [a, b] = co_await when_all( + returns_int(2), + returns_int(3) + ); + co_return a + b; + }; + + auto [x, y] = co_await when_all( + inner_failing(), + inner_success() + ); + co_return x + y; + }(), + [](int) {}, + [&](std::exception_ptr ep) { + caught_exception = true; + try { + std::rethrow_exception(ep); + } catch (test_exception const& e) { + BOOST_TEST_EQ(std::string(e.what()), "inner error"); + } + }); + + BOOST_TEST(caught_exception); + } + + void + run() + { + // Basic functionality + testResultType(); + testAllSucceed(); + testThreeTasksSucceed(); + testMixedTypes(); + testSingleTask(); + testFirstException(); + testMultipleFailuresFirstWins(); + testVoidTaskException(); + testNestedWhenAll(); + testAllVoidTasks(); + + // Frame allocator verification + testFrameAllocatorTwoTasks(); + testFrameAllocatorThreeTasks(); + testFrameAllocatorWithVoidTasks(); + testFrameAllocatorSingleTask(); + testFrameAllocatorNestedWhenAll(); + testFrameAllocatorWithException(); + + // Stop token propagation + testStopRequestedOnError(); + testAllTasksCompleteAfterStop(); + + // Edge cases + testManyTasks(); + testTasksWithMultipleSteps(); + testDifferentExceptionTypes(); + + // Dispatcher propagation + testDispatcherUsedForAllTasks(); + + // Result ordering + testResultsInInputOrder(); + testMixedVoidValueOrder(); + + // Awaitable lifecycle + testAwaitableMoveConstruction(); + testDeferredAwait(); + + // Stoppable awaitable protocol + testStoppableAwaitableConcept(); + testNestedWhenAllStopPropagation(); + } +}; + +TEST_SUITE( + when_all_test, + "boost.capy.when_all"); + +} // capy +} // boost