Skip to content

Commit a69f4fe

Browse files
committed
Reimplement run_coroutine() in C++ to simplify tracebacks
Signed-off-by: Greg Bonik <gbonik@nvidia.com>
1 parent 6d28570 commit a69f4fe

9 files changed

Lines changed: 243 additions & 38 deletions

File tree

cext/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ set(cext_include_dirs
4545
# Build a static library first, so that we could reuse it for several build targets
4646

4747
add_library(_cext_static STATIC
48+
coroutine_util.cpp
4849
cuda_loader.cpp
4950
cuda_helper.cpp
5051
memory.cpp

cext/coroutine_util.cpp

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
// SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#include "py.h"
6+
#include "vec.h"
7+
8+
#include <frameobject.h> // For PyFrame_Check on Python 3.10
9+
10+
11+
static void best_effort_cleanup_on_internal_error(Vec<PyPtr>& stack) {
12+
if (stack.empty())
13+
return;
14+
CHECK(PyErr_Occurred());
15+
PyPtr frame = try_getattr(stack.back(), "cr_frame");
16+
if (frame && PyFrame_Check(frame.get()))
17+
PyTraceBack_Here(reinterpret_cast<PyFrameObject*>(frame.get()));
18+
19+
ErrorGuard guard;
20+
while (!stack.empty()) {
21+
PyPtr coro = stack.back();
22+
stack.pop_back();
23+
Py_XDECREF(PyObject_CallMethod(coro.get(), "close", ""));
24+
if (PyErr_Occurred())
25+
PyErr_Print();
26+
}
27+
}
28+
29+
30+
// Run a coroutine using a software stack to bypass the Python's recursion limit.
31+
// Use resume_after() to break the call chain and push a new frame to the software stack.
32+
static PyObject* run_coroutine(PyObject* self, PyObject* main_coro) {
33+
if (!PyCoro_CheckExact(main_coro)) {
34+
raise(PyExc_TypeError, "Expected a coroutine");
35+
return nullptr;
36+
}
37+
38+
Vec<PyPtr> stack;
39+
stack.push_back(newref(main_coro));
40+
41+
PyPtr ret = newref(Py_None);
42+
SavedException exc;
43+
44+
while (!stack.empty()) {
45+
PyObject* coro = stack.back().get();
46+
if (!exc) {
47+
// Happy path: use PyIter_Send() C API for efficiency
48+
PyObject* res = nullptr;
49+
PySendResult send_res = PyIter_Send(coro, ret.get(), &res);
50+
if (send_res == PYGEN_RETURN) {
51+
ret = steal(res);
52+
stack.pop_back();
53+
} else if (send_res == PYGEN_NEXT) {
54+
PyPtr next_value = steal(res);
55+
ret = newref(Py_None);
56+
if (!PyCoro_CheckExact(next_value.get())) {
57+
raise(PyExc_TypeError, "Expected a continuation coroutine");
58+
best_effort_cleanup_on_internal_error(stack);
59+
return nullptr;
60+
}
61+
stack.push_back(std::move(next_value));
62+
} else {
63+
CHECK(send_res == PYGEN_ERROR);
64+
exc = save_raised_exception();
65+
exc.normalize();
66+
ret = newref(Py_None);
67+
stack.pop_back();
68+
}
69+
continue;
70+
}
71+
72+
// Slow path: need to call .throw() since there is no public C API to _gen_throw()
73+
PyPtr continuation = steal(PyObject_CallMethod(coro, "throw", "(O)", exc.value.get()));
74+
if (continuation) {
75+
ret = newref(Py_None);
76+
exc = {};
77+
if (!PyCoro_CheckExact(continuation.get())) {
78+
raise(PyExc_TypeError, "Expected a continuation coroutine");
79+
best_effort_cleanup_on_internal_error(stack);
80+
return nullptr;
81+
}
82+
stack.push_back(std::move(continuation));
83+
} else {
84+
CHECK(PyErr_Occurred());
85+
exc = save_raised_exception();
86+
exc.normalize();
87+
CHECK(exc.value);
88+
if (PyErr_GivenExceptionMatches(exc.value.get(), PyExc_StopIteration)) {
89+
ret = getattr(exc.value, "value");
90+
if (!ret) {
91+
best_effort_cleanup_on_internal_error(stack);
92+
return nullptr;
93+
}
94+
exc = {};
95+
} else {
96+
ret = newref(Py_None);
97+
}
98+
stack.pop_back();
99+
}
100+
}
101+
102+
if (exc) {
103+
exc.restore();
104+
return nullptr;
105+
}
106+
107+
return ret.release();
108+
}
109+
110+
111+
static PyMethodDef functions[] = {
112+
{"run_coroutine", run_coroutine, METH_O, ""},
113+
{}
114+
};
115+
116+
117+
Status coroutine_util_init(PyObject* m) {
118+
if (PyModule_AddFunctions(m, functions) < 0)
119+
return ErrorRaised;
120+
121+
return OK;
122+
}
123+

cext/coroutine_util.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#pragma once
6+
7+
Status coroutine_util_init(PyObject* m);
8+

cext/module.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "tile_kernel.h"
88
#include "cuda_helper.h"
9+
#include "coroutine_util.h"
910
#include "xla_ffi_py.h"
1011

1112
#ifdef _WIN32
@@ -48,6 +49,9 @@ PyMODINIT_FUNC PyInit__cext() {
4849
if (!cuda_helper_init(m.get()))
4950
return nullptr;
5051

52+
if (!coroutine_util_init(m.get()))
53+
return nullptr;
54+
5155
if (!xla_ffi_init(m.get()))
5256
return nullptr;
5357

cext/py.h

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,10 @@ ErrorRaised_t raise(PyObject* exctype, const char* fmt, Args&&... args) {
202202
struct SavedException {
203203
PyPtr type, value, traceback;
204204

205+
operator bool() const {
206+
return bool(type);
207+
}
208+
205209
void normalize() {
206210
PyObject* tmp_type = type.release();
207211
PyObject* tmp_value = value.release();
@@ -210,6 +214,15 @@ struct SavedException {
210214
type = steal(tmp_type);
211215
value = steal(tmp_value);
212216
traceback = steal(tmp_traceback);
217+
if (traceback)
218+
PyException_SetTraceback(value.get(), traceback.get());
219+
}
220+
221+
void restore() {
222+
PyObject* tmp_type = type.release();
223+
PyObject* tmp_value = value.release();
224+
PyObject* tmp_traceback = traceback.release();
225+
PyErr_Restore(tmp_type, tmp_value, tmp_traceback);
213226
}
214227
};
215228

@@ -250,10 +263,7 @@ struct ErrorGuard {
250263
void operator=(const ErrorGuard&) = delete;
251264

252265
~ErrorGuard() {
253-
PyObject* tmp_type = exc.type.release();
254-
PyObject* tmp_value = exc.value.release();
255-
PyObject* tmp_traceback = exc.traceback.release();
256-
PyErr_Restore(tmp_type, tmp_value, tmp_traceback);
266+
exc.restore();
257267
}
258268
};
259269

cext/vec.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,19 @@ class Vec {
158158
return data_[i];
159159
}
160160

161+
T& back() {
162+
return data_[size_ - 1];
163+
}
164+
165+
const T& back() const {
166+
return data_[size_ - 1];
167+
}
168+
169+
void pop_back() {
170+
back().~T();
171+
--size_;
172+
}
173+
161174
bool operator== (const Vec& other) const {
162175
size_t n = size_;
163176
if (n != other.size_) return false;

src/cuda/tile/_cext.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def _benchmark(stream: int,
105105
/) -> float: ...
106106

107107

108+
def run_coroutine(coro):
109+
"""
110+
Run a coroutine using a software stack to bypass the Python's recursion limit.
111+
Use resume_after() to break the call chain and push a new frame to the software stack.
112+
"""
113+
114+
108115
CU_TENSOR_MAP_DATA_TYPE_UINT8: int
109116
CU_TENSOR_MAP_DATA_TYPE_UINT16: int
110117
CU_TENSOR_MAP_DATA_TYPE_UINT32: int

src/cuda/tile/_coroutine_util.py

Lines changed: 4 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,10 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
import sys
6-
from contextlib import ExitStack
75
from dataclasses import dataclass
86
from typing import Awaitable
97

10-
11-
# Run a coroutine using a software stack to bypass the Python's recursion limit.
12-
# Use resume_after() to break the call chain and push a new frame to the software stack.
13-
def run_coroutine(awaitable: Awaitable):
14-
ret = None
15-
exc_info = None
16-
stack = []
17-
with ExitStack() as es:
18-
try:
19-
stack.append(awaitable.__await__())
20-
while stack:
21-
top = stack[-1]
22-
try:
23-
continuation = top.send(ret) if exc_info is None else top.throw(exc_info[1])
24-
except StopIteration as s:
25-
ret = s.value
26-
exc_info = None
27-
stack.pop()
28-
except Exception:
29-
ret = None
30-
exc_info = sys.exc_info()
31-
stack.pop()
32-
else:
33-
ret = exc_info = None
34-
stack.append(continuation.__await__())
35-
if exc_info is None:
36-
return ret
37-
else:
38-
raise exc_info[1]
39-
finally:
40-
for c in stack:
41-
es.callback(c.close)
8+
from cuda.tile._cext import run_coroutine
429

4310

4411
# Replace `await foo()` with `await resume_after(foo())` to bypass the recursion limit.
@@ -48,3 +15,6 @@ class resume_after:
4815

4916
def __await__(self):
5017
return (yield self.awaitable)
18+
19+
20+
__all__ = ["run_coroutine", "resume_after"]

test/test_coroutine_util.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,32 @@ def test_raise_then_catch():
5252
assert res == 100 + 2 + 3 + 4
5353

5454

55+
async def return_123():
56+
return 123
57+
58+
59+
async def raise_then_catch_and_call_another(n):
60+
if n == 0:
61+
raise ValueError("Hello")
62+
63+
if n == 1:
64+
try:
65+
await resume_after(raise_then_catch_and_call_another(0))
66+
except ValueError as e:
67+
assert str(e) == "Hello"
68+
x = await resume_after(return_123())
69+
return x
70+
assert False
71+
72+
r = await resume_after(raise_then_catch_and_call_another(n - 1))
73+
return r + n
74+
75+
76+
def test_raise_then_catch_and_call_another():
77+
res = run_coroutine(raise_then_catch_and_call_another(4))
78+
assert res == 123 + 2 + 3 + 4
79+
80+
5581
async def two_calls():
5682
t1 = await resume_after(series(3))
5783
t2 = await resume_after(series(4))
@@ -75,5 +101,48 @@ def test_traceback_preserved():
75101
try:
76102
run_coroutine(call_leaf())
77103
except ValueError as e:
104+
traceback.print_tb(e.__traceback__)
78105
frame_names = [f.name for f in traceback.extract_tb(e.__traceback__)]
79106
assert "raise_in_leaf" in frame_names
107+
else:
108+
assert False
109+
110+
111+
class WeirdAwaitable:
112+
def __await__(self):
113+
return iter([123])
114+
115+
116+
async def weird_await():
117+
await WeirdAwaitable()
118+
119+
120+
async def call_weird_await():
121+
await weird_await()
122+
123+
124+
def test_unexpected_awaitable():
125+
try:
126+
run_coroutine(call_weird_await())
127+
except TypeError as e:
128+
assert "Expected a continuation coroutine" in str(e)
129+
traceback.print_tb(e.__traceback__)
130+
frame_names = [f.name for f in traceback.extract_tb(e.__traceback__)]
131+
assert "call_weird_await" in frame_names
132+
else:
133+
assert False
134+
135+
136+
async def resume_after_call_weird_await(flag):
137+
try:
138+
await resume_after(call_weird_await())
139+
finally:
140+
flag[0] = True
141+
142+
143+
def test_cleanup_after_internal_error():
144+
flag = [False]
145+
coro = resume_after_call_weird_await(flag)
146+
with pytest.raises(TypeError, match="Expected a continuation coroutine"):
147+
run_coroutine(coro)
148+
assert flag[0] is True

0 commit comments

Comments
 (0)