WIP9: some improvements
Some checks failed
CI & Unit Tests / Unit-Tests (push) Failing after 8s
CI & Unit Tests / Docs (push) Successful in 10s

This commit is contained in:
Erki 2025-02-28 08:51:05 +02:00
parent 5e5aee38dc
commit acbdbf4f7f
3 changed files with 81 additions and 38 deletions

View File

@ -322,16 +322,6 @@ TEST_CASE("Signal awaiters work.", "[coro],[signal]")
namespace namespace
{ {
skullc::coro::Task<> test_wait_all()
{
co_await skullc::coro::wait_all(
test_sleepy_coro(0, 10ms),
test_sleepy_coro(1, 20ms)
);
co_return;
}
skullc::coro::Task<int> test_sleepy_coro_return(const int expected, const std::chrono::duration<uint32_t, std::milli>& duration = 10ms) skullc::coro::Task<int> test_sleepy_coro_return(const int expected, const std::chrono::duration<uint32_t, std::milli>& duration = 10ms)
{ {
co_await skullc::coro::sleep(0ms, duration); co_await skullc::coro::sleep(0ms, duration);
@ -340,6 +330,19 @@ skullc::coro::Task<int> test_sleepy_coro_return(const int expected, const std::c
co_return expected; co_return expected;
} }
skullc::coro::Task<> test_wait_all()
{
auto val = co_await skullc::coro::wait_all(
test_sleepy_coro(0, 10ms),
test_sleepy_coro_return(1, 20ms)
);
REQUIRE(std::get<0>(val) == std::monostate{});
REQUIRE(std::get<1>(val) == 1);
co_return;
}
skullc::coro::Task<> testwait_first() skullc::coro::Task<> testwait_first()
{ {
auto val = co_await skullc::coro::wait_first( auto val = co_await skullc::coro::wait_first(

View File

@ -16,14 +16,30 @@ namespace skullc::coro
namespace detail namespace detail
{ {
template<typename Awaitable>
struct AwaitableResumeType
{
using raw_type = typename std::conditional<
std::is_void_v<typename Awaitable::value_type>,
std::monostate,
typename Awaitable::value_type>::type;
using value_type = std::optional<raw_type>;
};
template<typename... Awaitables>
struct WaitAllAwaitable struct WaitAllAwaitable
{ {
WaitAllAwaitable() = delete; WaitAllAwaitable() = delete;
template<typename... Awaitables>
WaitAllAwaitable(Awaitables&&... args) WaitAllAwaitable(Awaitables&&... args)
: WaitAllAwaitable(std::index_sequence_for<Awaitables...>(), std::forward<Awaitables>(args)...)
{}
template<std::size_t... Is>
WaitAllAwaitable(std::index_sequence<Is...>, Awaitables&&... args)
{ {
(register_awaitable(std::forward<Awaitables>(args)), ...); ((std::get<Is>(result) = std::nullopt), ...);
(register_awaitable<Is>(std::forward<Awaitables>(args)), ...);
} }
~WaitAllAwaitable() ~WaitAllAwaitable()
@ -32,35 +48,47 @@ struct WaitAllAwaitable
this_coro::scheduler().remove(continuation); this_coro::scheduler().remove(continuation);
} }
template<typename Awaitable> template<std::size_t I, typename Awaitable>
void register_awaitable(Awaitable&& task) void register_awaitable(Awaitable&& task)
{ {
auto t = run_awaitable(task); start_awaitable();
auto t = run_awaitable<I>(task);
this_coro::scheduler().schedule(t.get_handle(), 0); this_coro::scheduler().schedule(t.get_handle(), 0);
t.detach(); t.detach();
} }
template<typename Awaitable> template<std::size_t I, typename Awaitable>
Task<> run_awaitable(Awaitable&& awaitable) Task<> run_awaitable(Awaitable&& awaitable)
{ {
start_awaitable(); using decayed = std::decay_t<Awaitable>;
if constexpr (std::is_void_v<typename decayed::value_type> == true)
{
co_await awaitable; co_await awaitable;
std::get<I>(result) = std::monostate{};
}
else
{
auto val = co_await awaitable;
std::get<I>(result) = std::move(val);
}
awaitable_completed(); awaitable_completed();
co_return; co_return;
} }
void awaitable_completed()
{
pending--;
if (pending == 0 && continuation)
this_coro::scheduler().schedule(continuation, 0);
}
void start_awaitable() void start_awaitable()
{ {
pending++; pending++;
} }
void awaitable_completed()
{
pending--;
if (pending == 0 && continuation)
this_coro::scheduler().schedule(continuation, 0);
}
auto operator co_await() noexcept auto operator co_await() noexcept
{ {
struct Awaitable struct Awaitable
@ -74,23 +102,20 @@ struct WaitAllAwaitable
wait_all->continuation = h; wait_all->continuation = h;
} }
void await_resume() auto await_resume()
{ } {
return std::move(wait_all->result);
}
}; };
return Awaitable{this}; return Awaitable{this};
} }
std::tuple<typename AwaitableResumeType<Awaitables>::value_type...> result;
std::coroutine_handle<> continuation{}; std::coroutine_handle<> continuation{};
int pending = 0; int pending = 0;
}; };
template<typename Awaitable>
struct AwaitableResumeType
{
using value_type = std::optional<typename Awaitable::value_type>;
};
template<typename... Awaitables> template<typename... Awaitables>
struct WaitFirstAwaitable struct WaitFirstAwaitable
{ {
@ -125,9 +150,18 @@ struct WaitFirstAwaitable
template<std::size_t I, typename Awaitable> template<std::size_t I, typename Awaitable>
Task<> run_awaitable(Awaitable&& awaitable) Task<> run_awaitable(Awaitable&& awaitable)
{
using decayed = std::decay_t<Awaitable>;
if constexpr (std::is_void_v<typename decayed::value_type> == true)
{
co_await awaitable;
awaitable_completed<I>(std::monostate{});
}
else
{ {
auto val = co_await awaitable; auto val = co_await awaitable;
awaitable_completed<I>(std::move(val)); awaitable_completed<I>(std::move(val));
}
co_return; co_return;
} }
@ -142,7 +176,6 @@ struct WaitFirstAwaitable
if (j != I) if (j != I)
{ {
this_coro::scheduler().remove(coroutines[j]); this_coro::scheduler().remove(coroutines[j]);
// @todo: also clean up a related poller, if necessary.
coroutines[j].destroy(); coroutines[j].destroy();
} }
} }
@ -185,7 +218,7 @@ struct WaitFirstAwaitable
template<typename... Awaitables> template<typename... Awaitables>
auto wait_all(Awaitables&&... args) auto wait_all(Awaitables&&... args)
{ {
return detail::WaitAllAwaitable(std::forward<Awaitables>(args)...); return detail::WaitAllAwaitable<Awaitables...>(std::forward<Awaitables>(args)...);
} }
template<typename... Awaitables> template<typename... Awaitables>

View File

@ -94,15 +94,22 @@ public:
bool remove(std::coroutine_handle<> handle) override bool remove(std::coroutine_handle<> handle) override
{ {
bool found = false;
for (auto& p : scheduled_) for (auto& p : scheduled_)
if (p.second == handle) if (p.second == handle)
{ {
std::swap(p, scheduled_.back()); std::swap(p, scheduled_.back());
scheduled_.pop_back(); scheduled_.pop_back();
std::make_heap(scheduled_.begin(), scheduled_.end(), cmp); std::make_heap(scheduled_.begin(), scheduled_.end(), cmp);
return true;
found = true;
break;
} }
return false;
if (pollers_.size() > 0)
pollers_.erase(std::remove_if(pollers_.begin(), pollers_.end(),
[handle](const auto* poller) { return poller->stored_coro == handle; }));
return found;
} }
private: private: