|
| 1 | +// SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
| 2 | +// |
| 3 | +// SPDX-License-Identifier: Apache-2.0 |
| 4 | + |
| 5 | +#include "py.h" |
| 6 | +#include "vec.h" |
| 7 | + |
| 8 | +#include <frameobject.h> // For PyFrame_Check on Python 3.10 |
| 9 | + |
| 10 | + |
| 11 | +static void best_effort_cleanup_on_internal_error(Vec<PyPtr>& stack) { |
| 12 | + if (stack.empty()) |
| 13 | + return; |
| 14 | + CHECK(PyErr_Occurred()); |
| 15 | + PyPtr frame = try_getattr(stack.back(), "cr_frame"); |
| 16 | + if (frame && PyFrame_Check(frame.get())) |
| 17 | + PyTraceBack_Here(reinterpret_cast<PyFrameObject*>(frame.get())); |
| 18 | + |
| 19 | + ErrorGuard guard; |
| 20 | + while (!stack.empty()) { |
| 21 | + PyPtr coro = stack.back(); |
| 22 | + stack.pop_back(); |
| 23 | + Py_XDECREF(PyObject_CallMethod(coro.get(), "close", "")); |
| 24 | + if (PyErr_Occurred()) |
| 25 | + PyErr_Print(); |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | + |
| 30 | +// Run a coroutine using a software stack to bypass the Python's recursion limit. |
| 31 | +// Use resume_after() to break the call chain and push a new frame to the software stack. |
| 32 | +static PyObject* run_coroutine(PyObject* self, PyObject* main_coro) { |
| 33 | + if (!PyCoro_CheckExact(main_coro)) { |
| 34 | + raise(PyExc_TypeError, "Expected a coroutine"); |
| 35 | + return nullptr; |
| 36 | + } |
| 37 | + |
| 38 | + Vec<PyPtr> stack; |
| 39 | + stack.push_back(newref(main_coro)); |
| 40 | + |
| 41 | + PyPtr ret = newref(Py_None); |
| 42 | + SavedException exc; |
| 43 | + |
| 44 | + while (!stack.empty()) { |
| 45 | + PyObject* coro = stack.back().get(); |
| 46 | + if (!exc) { |
| 47 | + // Happy path: use PyIter_Send() C API for efficiency |
| 48 | + PyObject* res = nullptr; |
| 49 | + PySendResult send_res = PyIter_Send(coro, ret.get(), &res); |
| 50 | + if (send_res == PYGEN_RETURN) { |
| 51 | + ret = steal(res); |
| 52 | + stack.pop_back(); |
| 53 | + } else if (send_res == PYGEN_NEXT) { |
| 54 | + PyPtr next_value = steal(res); |
| 55 | + ret = newref(Py_None); |
| 56 | + if (!PyCoro_CheckExact(next_value.get())) { |
| 57 | + raise(PyExc_TypeError, "Expected a continuation coroutine"); |
| 58 | + best_effort_cleanup_on_internal_error(stack); |
| 59 | + return nullptr; |
| 60 | + } |
| 61 | + stack.push_back(std::move(next_value)); |
| 62 | + } else { |
| 63 | + CHECK(send_res == PYGEN_ERROR); |
| 64 | + exc = save_raised_exception(); |
| 65 | + exc.normalize(); |
| 66 | + ret = newref(Py_None); |
| 67 | + stack.pop_back(); |
| 68 | + } |
| 69 | + continue; |
| 70 | + } |
| 71 | + |
| 72 | + // Slow path: need to call .throw() since there is no public C API to _gen_throw() |
| 73 | + PyPtr continuation = steal(PyObject_CallMethod(coro, "throw", "(O)", exc.value.get())); |
| 74 | + if (continuation) { |
| 75 | + ret = newref(Py_None); |
| 76 | + exc = {}; |
| 77 | + if (!PyCoro_CheckExact(continuation.get())) { |
| 78 | + raise(PyExc_TypeError, "Expected a continuation coroutine"); |
| 79 | + best_effort_cleanup_on_internal_error(stack); |
| 80 | + return nullptr; |
| 81 | + } |
| 82 | + stack.push_back(std::move(continuation)); |
| 83 | + } else { |
| 84 | + CHECK(PyErr_Occurred()); |
| 85 | + exc = save_raised_exception(); |
| 86 | + exc.normalize(); |
| 87 | + CHECK(exc.value); |
| 88 | + if (PyErr_GivenExceptionMatches(exc.value.get(), PyExc_StopIteration)) { |
| 89 | + ret = getattr(exc.value, "value"); |
| 90 | + if (!ret) { |
| 91 | + best_effort_cleanup_on_internal_error(stack); |
| 92 | + return nullptr; |
| 93 | + } |
| 94 | + exc = {}; |
| 95 | + } else { |
| 96 | + ret = newref(Py_None); |
| 97 | + } |
| 98 | + stack.pop_back(); |
| 99 | + } |
| 100 | + } |
| 101 | + |
| 102 | + if (exc) { |
| 103 | + exc.restore(); |
| 104 | + return nullptr; |
| 105 | + } |
| 106 | + |
| 107 | + return ret.release(); |
| 108 | +} |
| 109 | + |
| 110 | + |
| 111 | +static PyMethodDef functions[] = { |
| 112 | + {"run_coroutine", run_coroutine, METH_O, ""}, |
| 113 | + {} |
| 114 | +}; |
| 115 | + |
| 116 | + |
| 117 | +Status coroutine_util_init(PyObject* m) { |
| 118 | + if (PyModule_AddFunctions(m, functions) < 0) |
| 119 | + return ErrorRaised; |
| 120 | + |
| 121 | + return OK; |
| 122 | +} |
| 123 | + |
0 commit comments