Skip to content

Commit c487ff4

Browse files
committed
Fix PCSX::Coroutine lifetime and chaining
Brings the host-side coroutine in line with psyqo's fixes: proper move semantics, destructor, awaiting coroutine resumption on co_return, and operator co_await for safe chaining. Signed-off-by: Nicolas 'Pixel' Noble <nicolas@nobis-crew.org>
1 parent dc95ead commit c487ff4

1 file changed

Lines changed: 127 additions & 37 deletions

File tree

src/support/coroutine.h

Lines changed: 127 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -26,52 +26,93 @@ SOFTWARE.
2626

2727
#pragma once
2828

29-
#if defined(__APPLE__) && (__clang_major__ < 15)
30-
// Why has Apple become the Microsoft of Software Engineering?
31-
#include <experimental/coroutine>
32-
#else
3329
#include <coroutine>
34-
#endif
3530
#include <type_traits>
31+
#include <utility>
3632

3733
namespace PCSX {
3834

3935
template <typename T = void>
4036
struct Coroutine {
41-
#if defined(__APPLE__) && (__clang_major__ < 15)
42-
template <typename U>
43-
using CoroutineHandle = std::experimental::coroutine_handle<U>;
44-
using CoroutineHandleVoid = std::experimental::coroutine_handle<void>;
45-
#else
46-
template <typename U>
47-
using CoroutineHandle = std::coroutine_handle<U>;
48-
using CoroutineHandleVoid = std::coroutine_handle<void>;
49-
#endif
5037
struct Empty {};
5138
typedef typename std::conditional<std::is_void<T>::value, Empty, T>::type SafeT;
5239

5340
Coroutine() = default;
54-
Coroutine(Coroutine &&other) = default;
55-
Coroutine &operator=(Coroutine &&other) = default;
41+
42+
Coroutine(Coroutine &&other) {
43+
if (m_handle) m_handle.destroy();
44+
m_handle = other.m_handle;
45+
m_value = std::move(other.m_value);
46+
m_suspended = other.m_suspended;
47+
m_earlyResume = other.m_earlyResume;
48+
49+
other.m_handle = nullptr;
50+
other.m_value = SafeT{};
51+
other.m_suspended = true;
52+
other.m_earlyResume = false;
53+
}
54+
55+
Coroutine &operator=(Coroutine &&other) {
56+
if (this != &other) {
57+
if (m_handle) m_handle.destroy();
58+
m_handle = other.m_handle;
59+
m_value = std::move(other.m_value);
60+
m_suspended = other.m_suspended;
61+
m_earlyResume = other.m_earlyResume;
62+
63+
other.m_handle = nullptr;
64+
other.m_value = SafeT{};
65+
other.m_suspended = true;
66+
other.m_earlyResume = false;
67+
}
68+
return *this;
69+
}
70+
5671
Coroutine(Coroutine const &) = delete;
5772
Coroutine &operator=(Coroutine const &) = delete;
73+
~Coroutine() {
74+
if (m_handle) m_handle.destroy();
75+
m_handle = nullptr;
76+
}
5877

5978
struct Awaiter {
60-
constexpr bool await_ready() const noexcept { return false; }
61-
constexpr void await_suspend(CoroutineHandleVoid) const noexcept {}
79+
Awaiter(Awaiter &&other) = default;
80+
Awaiter &operator=(Awaiter &&other) = default;
81+
Awaiter(Awaiter const &) = default;
82+
Awaiter &operator=(Awaiter const &) = default;
83+
constexpr bool await_ready() const noexcept {
84+
bool ret = m_coroutine->m_earlyResume;
85+
m_coroutine->m_earlyResume = false;
86+
return ret;
87+
}
88+
constexpr void await_suspend(std::coroutine_handle<> h) { m_coroutine->m_suspended = true; }
6289
constexpr void await_resume() const noexcept {}
90+
91+
private:
92+
Awaiter(Coroutine *coroutine) : m_coroutine(coroutine) {}
93+
Coroutine *m_coroutine;
94+
friend struct Coroutine;
6395
};
6496

65-
Awaiter awaiter() { return {}; }
97+
Awaiter awaiter() { return Awaiter(this); }
98+
6699
void resume() {
67100
if (!m_handle) return;
101+
if (!m_suspended) {
102+
m_earlyResume = true;
103+
return;
104+
}
105+
m_suspended = false;
68106
m_handle.resume();
69107
}
70108

71109
bool done() {
72110
if (!m_handle) return true;
73111
bool isDone = m_handle.done();
74112
if (isDone) {
113+
if constexpr (!std::is_void<T>::value) {
114+
m_value = std::move(m_handle.promise().m_value);
115+
}
75116
m_handle.destroy();
76117
m_handle = nullptr;
77118
}
@@ -82,38 +123,87 @@ struct Coroutine {
82123

83124
private:
84125
struct PromiseVoid {
85-
Coroutine<T> get_return_object() { return Coroutine{std::move(CoroutineHandle<Promise>::from_promise(*this))}; }
86-
Awaiter initial_suspend() { return {}; }
87-
Awaiter final_suspend() noexcept { return {}; }
126+
Coroutine<> get_return_object() {
127+
return Coroutine<>{std::move(std::coroutine_handle<Promise>::from_promise(*this))};
128+
}
129+
std::suspend_always initial_suspend() { return {}; }
130+
std::suspend_always final_suspend() noexcept { return {}; }
88131
void unhandled_exception() {}
89132
template <typename From>
90-
Awaiter yield_value(From &&from) {
91-
return {};
133+
From yield_value(From &&from) {
134+
return std::forward<From>(from);
135+
}
136+
void return_void() {
137+
if (m_awaitingCoroutine) {
138+
m_awaitingCoroutine.resume();
139+
m_awaitingCoroutine = nullptr;
140+
}
92141
}
93-
void return_void() {}
142+
[[no_unique_address]] Empty m_value;
143+
std::coroutine_handle<> m_awaitingCoroutine;
94144
};
145+
95146
struct PromiseValue {
96-
PromiseValue(Coroutine<T> *c) : coroutine(c) {}
97-
Coroutine<T> get_return_object() { return Coroutine{std::move(CoroutineHandle<Promise>::from_promise(*this))}; }
98-
Awaiter initial_suspend() { return {}; }
99-
Awaiter final_suspend() noexcept { return {}; }
147+
Coroutine<T> get_return_object() {
148+
return Coroutine{std::move(std::coroutine_handle<Promise>::from_promise(*this))};
149+
}
150+
std::suspend_always initial_suspend() { return {}; }
151+
std::suspend_always final_suspend() noexcept { return {}; }
100152
void unhandled_exception() {}
101-
// This should be an std::convertible_to<T>, but Apple still doesn't have a fully C++-20 conformant library.
102153
template <typename From>
103-
Awaiter yield_value(From &&from) {
104-
coroutine->m_value = std::forward<From>(from);
105-
return {};
154+
From yield_value(From &&from) {
155+
return std::forward<From>(from);
156+
}
157+
void return_value(T &&value) {
158+
m_value = std::move(value);
159+
if (m_awaitingCoroutine) {
160+
m_awaitingCoroutine.resume();
161+
m_awaitingCoroutine = nullptr;
162+
}
106163
}
107-
void return_value(T &&value) { coroutine->m_value = std::forward(value); }
108-
Coroutine<T> *coroutine = nullptr;
164+
T m_value;
165+
std::coroutine_handle<> m_awaitingCoroutine;
109166
};
167+
110168
typedef typename std::conditional<std::is_void<T>::value, PromiseVoid, PromiseValue>::type Promise;
111-
Coroutine(CoroutineHandle<Promise> &&handle) : m_handle(std::move(handle)) {}
112-
CoroutineHandle<Promise> m_handle;
169+
170+
Coroutine(std::coroutine_handle<Promise> &&handle) : m_handle(std::move(handle)) {}
171+
172+
std::coroutine_handle<Promise> m_handle;
113173
[[no_unique_address]] SafeT m_value;
174+
bool m_suspended = true;
175+
bool m_earlyResume = false;
114176

115177
public:
116178
using promise_type = Promise;
179+
180+
struct ChainAwaiter {
181+
std::coroutine_handle<Promise> handle;
182+
183+
constexpr bool await_ready() { return handle.done(); }
184+
185+
void await_suspend(std::coroutine_handle<> h) {
186+
handle.promise().m_awaitingCoroutine = h;
187+
if (!handle.done()) handle.resume();
188+
}
189+
190+
constexpr T await_resume() {
191+
if constexpr (std::is_void<T>::value) {
192+
handle.destroy();
193+
return;
194+
} else {
195+
auto val = std::move(handle.promise().m_value);
196+
handle.destroy();
197+
return val;
198+
}
199+
}
200+
};
201+
202+
ChainAwaiter operator co_await() && {
203+
auto h = m_handle;
204+
m_handle = nullptr;
205+
return ChainAwaiter{h};
206+
}
117207
};
118208

119209
} // namespace PCSX

0 commit comments

Comments
 (0)