diff --git a/Tests/coro.cpp b/Tests/coro.cpp index c2bb97a..d66f062 100644 --- a/Tests/coro.cpp +++ b/Tests/coro.cpp @@ -10,6 +10,7 @@ #include "skullc/coro/sleep.hpp" #include "skullc/coro/task.hpp" #include "skullc/coro/this_coro.hpp" +#include "skullc/coro/composition.hpp" #include @@ -317,3 +318,89 @@ TEST_CASE("Signal awaiters work.", "[coro],[signal]") } } } + +namespace +{ + +skullc::coro::Task<> test_wait_all() +{ + co_await skullc::coro::wait_all( + test_sleepy_coro(0, 10ms), + test_sleepy_coro(1, 20ms) + ); + + co_return; +} + +skullc::coro::Task test_sleepy_coro_return(const int expected, const std::chrono::duration& duration = 10ms) +{ + co_await skullc::coro::sleep(0ms, duration); + REQUIRE(expected == test_coro_called); + test_coro_called++; + co_return expected; +} + +skullc::coro::Task<> testwait_first() +{ + auto val = co_await skullc::coro::wait_first( + test_sleepy_coro_return(0, 10ms), + test_sleepy_coro_return(1, 20ms) + ); + + REQUIRE(std::get<0>(val) == 0); + REQUIRE(std::get<1>(val) == std::nullopt); + + co_return; +} + +} + +TEST_CASE("Wait all awaiter works.", "[coro],[wait_all]") +{ + using namespace skullc::coro; + Scheduler scheduler; + skullc::this_coro::scheduler.register_scheduler(scheduler); + + test_coro_called = 0; + + scheduler.start_tasks(test_wait_all()); + scheduler.loop(0); + scheduler.loop(0); + scheduler.loop(0); + + REQUIRE(test_coro_called == 0); + scheduler.loop(10); + scheduler.loop(10); + + REQUIRE(test_coro_called == 1); + scheduler.loop(20); + scheduler.loop(20); + scheduler.loop(20); + + REQUIRE(test_coro_called == 2); +} + +TEST_CASE("Wait one awaiter works.", "[coro],[wait_first]") +{ + using namespace skullc::coro; + Scheduler scheduler; + skullc::this_coro::scheduler.register_scheduler(scheduler); + + test_coro_called = 0; + + scheduler.start_tasks(testwait_first()); + scheduler.loop(0); + scheduler.loop(0); + scheduler.loop(0); + + REQUIRE(test_coro_called == 0); + scheduler.loop(10); + scheduler.loop(10); + + REQUIRE(test_coro_called == 1); + scheduler.loop(20); + scheduler.loop(20); + scheduler.loop(20); + + REQUIRE(test_coro_called == 1); +} diff --git a/coro/inc/skullc/coro/composition.hpp b/coro/inc/skullc/coro/composition.hpp new file mode 100644 index 0000000..a15211c --- /dev/null +++ b/coro/inc/skullc/coro/composition.hpp @@ -0,0 +1,197 @@ +// +// Created by erki on 25/02/25. +// + +#pragma once + +#include "skullc/coro/this_coro.hpp" +#include "skullc/coro/task.hpp" + +#include +#include + +namespace skullc::coro +{ + +namespace detail +{ + +struct WaitAllAwaitable +{ + WaitAllAwaitable() = delete; + + template + WaitAllAwaitable(Awaitables&&... args) + { + (register_awaitable(std::forward(args)), ...); + } + + ~WaitAllAwaitable() + { + if (continuation) + this_coro::scheduler().remove(continuation); + } + + template + void register_awaitable(Awaitable&& task) + { + auto t = run_awaitable(task); + this_coro::scheduler().schedule(t.get_handle(), 0); + t.detach(); + } + + template + Task<> run_awaitable(Awaitable&& awaitable) + { + start_awaitable(); + co_await awaitable; + awaitable_completed(); + co_return; + } + + void start_awaitable() + { + pending++; + } + + void awaitable_completed() + { + pending--; + if (pending == 0 && continuation) + this_coro::scheduler().schedule(continuation, 0); + } + + auto operator co_await() noexcept + { + struct Awaitable + { + WaitAllAwaitable* wait_all; + + bool await_ready() { return false; } + + void await_suspend(std::coroutine_handle<> h) + { + wait_all->continuation = h; + } + + void await_resume() + { } + }; + + return Awaitable{this}; + } + + std::coroutine_handle<> continuation{}; + int pending = 0; +}; + +template +struct AwaitableResumeType +{ + using value_type = std::optional; +}; + +template +struct WaitFirstAwaitable +{ + WaitFirstAwaitable() = delete; + + WaitFirstAwaitable(Awaitables&&... args) + : WaitFirstAwaitable(std::index_sequence_for(), std::forward(args)...) + {} + + template + WaitFirstAwaitable(std::index_sequence, Awaitables&&... args) + { + ((std::get(result) = std::nullopt), ...); + pending = true; + (register_awaitable(std::forward(args)), ...); + } + + ~WaitFirstAwaitable() + { + if (continuation) + this_coro::scheduler().remove(continuation); + } + + template + void register_awaitable(Awaitable&& task) + { + auto t = run_awaitable(task); + this_coro::scheduler().schedule(t.get_handle(), 0); + coroutines[I] = t.get_handle(); + t.detach(); + } + + template + Task<> run_awaitable(Awaitable&& awaitable) + { + auto val = co_await awaitable; + awaitable_completed(std::move(val)); + co_return; + } + + template + void awaitable_completed(auto&& val) + { + if (pending) + { + pending = false; + for (auto j = 0; j < coroutines.size(); j++) + { + if (j != I) + { + this_coro::scheduler().remove(coroutines[j]); + // @todo: also clean up a related poller, if necessary. + coroutines[j].destroy(); + } + } + + std::get(result) = val; + this_coro::scheduler().schedule(continuation, 0); + } + } + + auto operator co_await() noexcept + { + struct Awaitable + { + WaitFirstAwaitable* wait_first; + + bool await_ready() { return false; } + + void await_suspend(std::coroutine_handle<> h) + { + wait_first->continuation = h; + } + + auto await_resume() + { + return std::move(wait_first->result); + } + }; + + return Awaitable{this}; + } + + std::array, sizeof...(Awaitables)> coroutines; + std::tuple::value_type...> result; + std::coroutine_handle<> continuation{}; + int pending = 0; +}; + +} + +template +auto wait_all(Awaitables&&... args) +{ + return detail::WaitAllAwaitable(std::forward(args)...); +} + +template +auto wait_first(Awaitables&&... args) +{ + return detail::WaitFirstAwaitable(std::forward(args)...); +} + +}