Skip to content

Commit 66c577d

Browse files
authored
Merge pull request #25 from RustedBytes/enhance-tasks
fix tasks implementation and reduce monkey patching
2 parents 119d2eb + f603dc5 commit 66c577d

5 files changed

Lines changed: 225 additions & 102 deletions

File tree

python/rsloop/__init__.py

Lines changed: 3 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
import io as __io
1212

1313
__DLL_DIR_HANDLES: list[object] = []
14-
__TASK_KWARGS: set[str] | None = None
1514

1615

1716
def __configure_windows_dll_search_path() -> None:
@@ -161,26 +160,6 @@ def profile() -> __typing.Iterator[None]:
161160
_os = __os
162161

163162

164-
def __task_kwargs() -> set[str]:
165-
global __TASK_KWARGS
166-
if __TASK_KWARGS is not None:
167-
return __TASK_KWARGS
168-
169-
import inspect as __inspect
170-
171-
try:
172-
signature = __inspect.signature(__asyncio.Task)
173-
__TASK_KWARGS = {
174-
name
175-
for name, parameter in signature.parameters.items()
176-
if parameter.kind is __inspect.Parameter.KEYWORD_ONLY
177-
}
178-
except (TypeError, ValueError):
179-
__TASK_KWARGS = {"loop", "name"}
180-
181-
return __TASK_KWARGS
182-
183-
184163
def __get_asyncgen_state(loop: Loop) -> dict[str, object]:
185164
state = __ASYNCGEN_STATE.get(loop)
186165
if state is None:
@@ -286,42 +265,6 @@ def __loop_close(self):
286265
__LOOP_CONFIG.pop(self, None)
287266

288267

289-
def __loop_create_task(self, coro, *, name=None, context=None, **kwargs):
290-
task_kwargs_supported = __task_kwargs()
291-
factory = self.get_task_factory()
292-
if factory is not None:
293-
factory_kwargs = dict(kwargs)
294-
factory_kwargs["name"] = name
295-
if context is not None:
296-
factory_kwargs["context"] = context
297-
task = factory(self, coro, **factory_kwargs)
298-
return task
299-
300-
task_kwargs = {}
301-
extra_kwargs = dict(kwargs)
302-
303-
if "name" in task_kwargs_supported:
304-
task_kwargs["name"] = name
305-
if context is not None and "context" in task_kwargs_supported:
306-
task_kwargs["context"] = context
307-
if "eager_start" in extra_kwargs:
308-
eager_start = extra_kwargs.pop("eager_start")
309-
if "eager_start" in task_kwargs_supported:
310-
task_kwargs["eager_start"] = eager_start
311-
312-
if extra_kwargs:
313-
unexpected = next(iter(extra_kwargs))
314-
raise TypeError(
315-
f"create_task() got an unexpected keyword argument {unexpected!r}"
316-
)
317-
318-
task = __asyncio.Task(coro, loop=self, **task_kwargs)
319-
source_traceback = getattr(task, "_source_traceback", None)
320-
if source_traceback:
321-
del source_traceback[-1]
322-
return task
323-
324-
325268
def __cancel_all_tasks(loop: Loop) -> None:
326269
to_cancel = __asyncio.all_tasks(loop)
327270
if not to_cancel:
@@ -1597,8 +1540,9 @@ async def __loop_create_connection(
15971540
if Loop.close is __ORIG_CLOSE:
15981541
Loop.close = __loop_close
15991542

1600-
if Loop.create_task is __ORIG_CREATE_TASK:
1601-
Loop.create_task = __loop_create_task
1543+
# Keep the Rust implementation on the hot path. It already handles task
1544+
# factories and keyword forwarding, while the Python wrapper adds measurable
1545+
# overhead in task-heavy workloads.
16021546

16031547

16041548
def __install_ssl_tracking() -> None:

src/loop_core.rs

Lines changed: 49 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,40 @@ impl LoopCore {
282282
captured,
283283
context_needs_run,
284284
));
285-
self.send_command(LoopCommand::ScheduleReady(Arc::clone(&ready)))
286-
.map_err(|err| pyo3::exceptions::PyRuntimeError::new_err(err.to_string()))?;
285+
286+
match self.try_enqueue_local_ready(ReadyItem::Callback(Arc::clone(&ready))) {
287+
Ok(()) => return Ok(ready),
288+
Err(ReadyItem::Callback(callback)) => {
289+
if self
290+
.try_enqueue_active_ready(ReadyItem::Callback(callback))
291+
.is_ok()
292+
{
293+
return Ok(ready);
294+
}
295+
}
296+
Err(
297+
ReadyItem::Stop
298+
| ReadyItem::FutureSetResult { .. }
299+
| ReadyItem::FutureSetException { .. }
300+
| ReadyItem::StreamTransportRead(_)
301+
| ReadyItem::ProcessTransport(_),
302+
) => unreachable!("schedule_callback only enqueues callback ready items"),
303+
}
304+
305+
self.command_tx
306+
.send(LoopCommand::ScheduleReady(Arc::clone(&ready)))
307+
.map_err(|_| {
308+
pyo3::exceptions::PyRuntimeError::new_err(LoopCoreError::ChannelClosed.to_string())
309+
})?;
310+
#[cfg(target_os = "linux")]
311+
if let Some(waker) = self
312+
.runtime_waker
313+
.lock()
314+
.expect("poisoned runtime waker")
315+
.as_ref()
316+
{
317+
waker.wake_by_ref();
318+
}
287319
Ok(ready)
288320
}
289321

@@ -372,19 +404,23 @@ impl LoopCore {
372404
let mut processed_since_refill = 0_usize;
373405
loop {
374406
if ready_batch.is_empty() || processed_since_refill >= READY_DRAIN_SLICE {
375-
let mut pending = pending_ready.lock().expect("poisoned pending ready queue");
376-
if !pending.is_empty() {
377-
if ready_batch.is_empty() {
378-
std::mem::swap(&mut ready_batch, &mut *pending);
379-
} else {
380-
pending.append(&mut ready_batch);
381-
std::mem::swap(&mut ready_batch, &mut *pending);
407+
let should_check_pending =
408+
ready_batch.is_empty() || wake_pending.load(Ordering::Acquire);
409+
if should_check_pending {
410+
let mut pending =
411+
pending_ready.lock().expect("poisoned pending ready queue");
412+
if !pending.is_empty() {
413+
if ready_batch.is_empty() {
414+
std::mem::swap(&mut ready_batch, &mut *pending);
415+
} else {
416+
pending.append(&mut ready_batch);
417+
std::mem::swap(&mut ready_batch, &mut *pending);
418+
}
419+
}
420+
if pending.is_empty() {
421+
wake_pending.store(false, Ordering::Release);
382422
}
383423
}
384-
if pending.is_empty() {
385-
wake_pending.store(false, Ordering::Release);
386-
}
387-
drop(pending);
388424

389425
// Prioritize cross-thread wakeups such as signals and transport
390426
// connection_lost notifications so they cannot be starved by a

src/python_api.rs

Lines changed: 150 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ use crate::tls::{client_tls_settings, server_tls_settings};
4343
static ASYNCIO_TASK_CLS: OnceLock<Py<PyAny>> = OnceLock::new();
4444
static ASYNCIO_FUTURE_CLS: OnceLock<Py<PyAny>> = OnceLock::new();
4545
static ASYNCIO_GET_RUNNING_LOOP_FN: OnceLock<Py<PyAny>> = OnceLock::new();
46+
static ASYNCIO_TASK_KWARG_SUPPORT: OnceLock<TaskKwargSupport> = OnceLock::new();
4647
#[cfg(any(Py_3_12, all(Py_3_11, not(Py_LIMITED_API))))]
4748
static ASYNCIO_FUTURE_LOOP_KWNAMES: OnceLock<Py<PyTuple>> = OnceLock::new();
4849
#[cfg(any(Py_3_12, all(Py_3_11, not(Py_LIMITED_API))))]
@@ -56,6 +57,12 @@ static ASYNCIO_TASK_LOOP_NAME_CONTEXT_KWNAMES: OnceLock<Py<PyTuple>> = OnceLock:
5657

5758
type ResolvedStreamAddrinfo = (i32, i32, i32, Py<PyAny>);
5859

60+
struct TaskKwargSupport {
61+
name: bool,
62+
context: bool,
63+
eager_start: bool,
64+
}
65+
5966
struct TcpServerSocketOptions {
6067
family: i32,
6168
flags: i32,
@@ -151,6 +158,52 @@ fn asyncio_get_running_loop_fn(py: Python<'_>) -> PyResult<&Py<PyAny>> {
151158
Ok(ASYNCIO_GET_RUNNING_LOOP_FN.get_or_init(|| loaded))
152159
}
153160

161+
fn asyncio_task_kwarg_support(py: Python<'_>) -> PyResult<&'static TaskKwargSupport> {
162+
if let Some(cached) = ASYNCIO_TASK_KWARG_SUPPORT.get() {
163+
return Ok(cached);
164+
}
165+
166+
let inspect = py.import("inspect")?;
167+
let signature = match inspect
168+
.getattr("signature")?
169+
.call1((asyncio_task_cls(py)?.clone_ref(py),))
170+
{
171+
Ok(signature) => signature,
172+
Err(_) => {
173+
return Ok(ASYNCIO_TASK_KWARG_SUPPORT.get_or_init(|| TaskKwargSupport {
174+
name: true,
175+
context: false,
176+
eager_start: false,
177+
}));
178+
}
179+
};
180+
let parameters = signature.getattr("parameters")?;
181+
let keyword_only = inspect.getattr("Parameter")?.getattr("KEYWORD_ONLY")?;
182+
let mut support = TaskKwargSupport {
183+
name: false,
184+
context: false,
185+
eager_start: false,
186+
};
187+
188+
for kwarg_name in ["name", "context", "eager_start"] {
189+
let parameter = match parameters.get_item(kwarg_name) {
190+
Ok(parameter) => parameter,
191+
Err(_) => continue,
192+
};
193+
if !parameter.getattr("kind")?.eq(&keyword_only)? {
194+
continue;
195+
}
196+
match kwarg_name {
197+
"name" => support.name = true,
198+
"context" => support.context = true,
199+
"eager_start" => support.eager_start = true,
200+
_ => {}
201+
}
202+
}
203+
204+
Ok(ASYNCIO_TASK_KWARG_SUPPORT.get_or_init(|| support))
205+
}
206+
154207
#[inline]
155208
fn call_callable_noargs(py: Python<'_>, callable: &Py<PyAny>) -> PyResult<Py<PyAny>> {
156209
unsafe { Bound::from_owned_ptr_or_err(py, ffi::compat::PyObject_CallNoArgs(callable.as_ptr())) }
@@ -330,6 +383,22 @@ fn create_asyncio_task_with_kwargs(
330383
asyncio_task_cls(py)?.call(py, (coro,), Some(&task_kwargs))
331384
}
332385

386+
fn trim_task_source_traceback(py: Python<'_>, task: &Py<PyAny>) -> PyResult<()> {
387+
let Ok(source_traceback) = task.getattr(py, "_source_traceback") else {
388+
return Ok(());
389+
};
390+
if source_traceback.is_none(py) {
391+
return Ok(());
392+
}
393+
394+
let source_traceback = source_traceback.bind(py);
395+
if source_traceback.len()? == 0 {
396+
return Ok(());
397+
}
398+
399+
source_traceback.del_item(source_traceback.len()? - 1)
400+
}
401+
333402
fn call_protocol_factory(
334403
py: Python<'_>,
335404
loop_obj: &Py<PyAny>,
@@ -1298,22 +1367,14 @@ impl PyLoop {
12981367
) -> PyResult<Py<PyAny>> {
12991368
let core = Arc::clone(&slf.borrow(py).core);
13001369
let loop_obj = Self::as_py_any(py, &slf);
1301-
let task_kwargs = PyDict::new(py);
1302-
if let Some(kwargs_in) = kwargs.as_ref() {
1303-
for (key, value) in kwargs_in.bind(py).iter() {
1304-
task_kwargs.set_item(key, value)?;
1305-
}
1306-
}
1307-
if let Some(name) = name.as_ref() {
1308-
task_kwargs.set_item("name", name)?;
1309-
}
1310-
if let Some(context) = context.as_ref() {
1311-
task_kwargs.set_item("context", context)?;
1312-
}
1313-
if let Some(eager_start) = eager_start {
1314-
task_kwargs.set_item("eager_start", eager_start)?;
1315-
}
1316-
let has_kwargs = !task_kwargs.is_empty();
1370+
let task_kwarg_support = asyncio_task_kwarg_support(py)?;
1371+
let extra_kwargs = kwargs
1372+
.as_ref()
1373+
.is_some_and(|kwargs| !kwargs.bind(py).is_empty());
1374+
let has_kwargs = extra_kwargs
1375+
|| name.is_some()
1376+
|| (context.is_some() && task_kwarg_support.context)
1377+
|| (eager_start.is_some() && task_kwarg_support.eager_start);
13171378

13181379
if !core.has_task_factory() && !has_kwargs && core.on_runtime_thread() {
13191380
return create_asyncio_task_for_running_loop(py, coro);
@@ -1329,23 +1390,88 @@ impl PyLoop {
13291390
} else {
13301391
None
13311392
};
1393+
1394+
if task_factory.is_none() && extra_kwargs {
1395+
let unexpected = kwargs
1396+
.as_ref()
1397+
.and_then(|kwargs| kwargs.bind(py).iter().next().map(|(key, _)| key))
1398+
.expect("non-empty kwargs when extra_kwargs is true");
1399+
let unexpected = unexpected.repr()?.extract::<String>()?;
1400+
return Err(PyTypeError::new_err(format!(
1401+
"create_task() got an unexpected keyword argument {unexpected}"
1402+
)));
1403+
}
1404+
1405+
let task_kwargs = if has_kwargs || task_factory.is_some() {
1406+
let task_kwargs = PyDict::new(py);
1407+
if let Some(kwargs_in) = kwargs.as_ref() {
1408+
for (key, value) in kwargs_in.bind(py).iter() {
1409+
task_kwargs.set_item(key, value)?;
1410+
}
1411+
}
1412+
if task_factory.is_some() {
1413+
let factory_name = name
1414+
.as_ref()
1415+
.map(|name| name.clone_ref(py))
1416+
.unwrap_or_else(|| py.None());
1417+
task_kwargs.set_item("name", factory_name)?;
1418+
} else if task_kwarg_support.name {
1419+
if let Some(name) = name.as_ref() {
1420+
task_kwargs.set_item("name", name)?;
1421+
}
1422+
}
1423+
if let Some(context) = context.as_ref() {
1424+
if task_factory.is_some() || task_kwarg_support.context {
1425+
task_kwargs.set_item("context", context)?;
1426+
}
1427+
}
1428+
if let Some(eager_start) = eager_start {
1429+
if task_factory.is_some() || task_kwarg_support.eager_start {
1430+
task_kwargs.set_item("eager_start", eager_start)?;
1431+
}
1432+
}
1433+
Some(task_kwargs)
1434+
} else {
1435+
None
1436+
};
1437+
13321438
if let Some(factory) = task_factory {
1333-
let created = factory.call(py, (loop_obj.clone_ref(py), coro), Some(&task_kwargs))?;
1439+
let created = factory.call(py, (loop_obj.clone_ref(py), coro), task_kwargs.as_ref())?;
13341440
return Ok(created);
13351441
}
13361442

1443+
let trim_source_traceback = core.get_debug();
13371444
if is_current_running_loop(py, &loop_obj)? {
1338-
if !has_kwargs {
1339-
return create_asyncio_task_for_running_loop(py, coro);
1445+
let created = if !has_kwargs {
1446+
create_asyncio_task_for_running_loop(py, coro)?
1447+
} else {
1448+
create_asyncio_task_with_kwargs(
1449+
py,
1450+
None,
1451+
coro,
1452+
task_kwargs.as_ref().expect("task kwargs"),
1453+
)?
1454+
};
1455+
if trim_source_traceback {
1456+
trim_task_source_traceback(py, &created)?;
13401457
}
1341-
return create_asyncio_task_with_kwargs(py, None, coro, &task_kwargs);
1458+
return Ok(created);
13421459
}
13431460

1344-
if !has_kwargs {
1345-
return create_asyncio_task_for_loop(py, &loop_obj, coro, name, context);
1461+
let created = if !has_kwargs {
1462+
create_asyncio_task_for_loop(py, &loop_obj, coro, name, context)?
1463+
} else {
1464+
create_asyncio_task_with_kwargs(
1465+
py,
1466+
Some(&loop_obj),
1467+
coro,
1468+
task_kwargs.as_ref().expect("task kwargs"),
1469+
)?
1470+
};
1471+
if trim_source_traceback {
1472+
trim_task_source_traceback(py, &created)?;
13461473
}
1347-
1348-
create_asyncio_task_with_kwargs(py, Some(&loop_obj), coro, &task_kwargs)
1474+
Ok(created)
13491475
}
13501476

13511477
fn set_task_factory(&self, factory: Option<Py<PyAny>>) {

0 commit comments

Comments
 (0)