Skip to content

Commit 4feb153

Browse files
Break reference cycles in context chain
1 parent 8194fe0 commit 4feb153

File tree

1 file changed

+189
-12
lines changed

1 file changed

+189
-12
lines changed

src/err/err_state.rs

Lines changed: 189 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,15 @@ use std::{
66

77
#[cfg(not(Py_3_12))]
88
use crate::sync::MutexExt;
9+
#[cfg(Py_3_12)]
10+
use crate::types::{PyString, PyTuple};
911
use crate::{
1012
exceptions::{PyBaseException, PyTypeError},
1113
ffi,
1214
ffi_ptr_ext::FfiPtrExt,
1315
types::{PyAnyMethods, PyTraceback, PyType},
1416
Bound, Py, PyAny, PyErrArguments, PyTypeInfo, Python,
1517
};
16-
#[cfg(Py_3_12)]
17-
use {
18-
crate::types::{PyString, PyTuple},
19-
std::ptr::NonNull,
20-
};
2118

2219
pub(crate) struct PyErrState {
2320
// Safety: can only hand out references when in the "normalized" state. Will never change
@@ -427,9 +424,36 @@ fn create_normalized_exception<'py>(
427424
unsafe { ptype.cast_unchecked() }.clone()
428425
};
429426

427+
let mut current_handled_exception: Option<Bound<'_, PyBaseException>> = unsafe {
428+
ffi::PyErr_GetHandledException()
429+
.assume_owned_or_opt(py)
430+
.map(|obj| obj.cast_into_unchecked())
431+
};
432+
430433
let pvalue = if pvalue.is_exact_instance(&ptype) {
431434
// Safety: already an exception value of the correct type
432-
Ok(unsafe { pvalue.cast_into_unchecked::<PyBaseException>() })
435+
let exc = unsafe { pvalue.cast_into_unchecked::<PyBaseException>() };
436+
437+
if current_handled_exception
438+
.as_ref()
439+
.map(|current| current.is(&exc))
440+
.unwrap_or_default()
441+
{
442+
// Current exception is the same as it's context so do not set the context to avoid a loop
443+
current_handled_exception = None;
444+
} else if let Some(current_context) = current_handled_exception.as_ref() {
445+
// Check if this exception is already in the context chain, so we do not create reference cycles in the context chain.
446+
let mut iter = context_chain_iter(current_context.clone()).peekable();
447+
while let Some((current, next)) = iter.next().zip(iter.peek()) {
448+
if next.is(&exc) {
449+
// Loop in context chain, breaking the loop by not pointing to exc
450+
unsafe { ffi::PyException_SetContext(current.as_ptr(), std::ptr::null_mut()) };
451+
break;
452+
}
453+
}
454+
}
455+
456+
Ok(exc)
433457
} else if pvalue.is_none() {
434458
// None -> no arguments
435459
ptype.call0().and_then(|pvalue| Ok(pvalue.cast_into()?))
@@ -445,19 +469,67 @@ fn create_normalized_exception<'py>(
445469

446470
match pvalue {
447471
Ok(pvalue) => {
448-
unsafe {
449-
if let Some(context) = NonNull::new(ffi::PyErr_GetHandledException()) {
450-
ffi::PyException_SetContext(pvalue.as_ptr(), context.as_ptr())
451-
}
452-
};
472+
// Implicitly set the context of the new exception to the currently handled exception, if any.
473+
if let Some(context) = current_handled_exception {
474+
unsafe { ffi::PyException_SetContext(pvalue.as_ptr(), context.into_ptr()) };
475+
}
453476
pvalue
454477
}
455478
Err(e) => e.value(py).clone(),
456479
}
457480
}
458481

482+
/// Iterates through the context chain of exceptions, starting from `start`, and yields each exception in the chain.
483+
/// When there is a loop in the chain it may yield some elements multiple times, but it will always terminate.
484+
#[inline]
485+
#[cfg(Py_3_12)]
486+
fn context_chain_iter(
487+
start: Bound<'_, PyBaseException>,
488+
) -> impl Iterator<Item = Bound<'_, PyBaseException>> {
489+
#[inline]
490+
fn get_next<'py>(current: &Bound<'py, PyBaseException>) -> Option<Bound<'py, PyBaseException>> {
491+
unsafe {
492+
ffi::PyException_GetContext(current.as_ptr())
493+
.assume_owned_or_opt(current.py())
494+
.map(|obj| obj.cast_into_unchecked())
495+
}
496+
}
497+
498+
let mut slow = None;
499+
let mut current = Some(start);
500+
let mut slow_update_toggle = false;
501+
502+
std::iter::from_fn(move || {
503+
let next = get_next(current.as_ref()?);
504+
505+
// Detect loops in the context chain using Floyd's Tortoise and Hare algorithm.
506+
if let Some((current_slow, current_fast)) = slow.as_ref().zip(next.as_ref()) {
507+
if current_fast.is(current_slow) {
508+
// Loop detected
509+
return current.take();
510+
}
511+
512+
// Every second iteration, advance the slow pointer by one step
513+
if slow_update_toggle {
514+
slow = get_next(current_slow);
515+
}
516+
517+
slow_update_toggle = !slow_update_toggle;
518+
}
519+
520+
// Set the slow pointer after the first iteration
521+
if slow.is_none() {
522+
slow = current.clone()
523+
}
524+
525+
std::mem::replace(&mut current, next)
526+
})
527+
}
528+
459529
#[cfg(test)]
460530
mod tests {
531+
#[cfg(Py_3_12)]
532+
use crate::{exceptions::PyBaseException, ffi, Bound};
461533
use crate::{
462534
exceptions::PyValueError, sync::PyOnceLock, Py, PyAny, PyErr, PyErrArguments, Python,
463535
};
@@ -571,7 +643,10 @@ mod tests {
571643
Bound,
572644
};
573645

574-
fn test_exception<'py>(ptype: &Bound<'py, PyAny>, pvalue: Bound<'py, PyAny>) {
646+
fn test_exception<'py>(
647+
ptype: &Bound<'py, PyAny>,
648+
pvalue: Bound<'py, PyAny>,
649+
) -> (PyErr, PyErr) {
575650
let py = ptype.py();
576651

577652
let exc1 = super::create_normalized_exception(ptype, pvalue.clone());
@@ -592,6 +667,15 @@ mod tests {
592667
assert!(err1.traceback(py).xor(err2.traceback(py)).is_none());
593668
assert!(err1.cause(py).xor(err2.cause(py)).is_none());
594669
assert_eq!(err1.to_string(), err2.to_string());
670+
671+
super::context_chain_iter(err1.value(py).clone())
672+
.zip(super::context_chain_iter(err2.value(py).clone()))
673+
.for_each(|(context1, context2)| {
674+
assert!(context1.get_type().is(context2.get_type()));
675+
assert_eq!(context1.to_string(), context2.to_string());
676+
});
677+
678+
(err1, err2)
595679
}
596680

597681
Python::attach(|py| {
@@ -614,6 +698,99 @@ mod tests {
614698
.into_any()
615699
.into_bound(py),
616700
);
701+
702+
// Loop where err is not part of the loop
703+
let looped_context = create_loop(py, 3);
704+
let err = PyRuntimeError::new_err("Boom");
705+
with_handled_exception(looped_context.value(py), || {
706+
let (normalized, _) = test_exception(
707+
&PyRuntimeError::type_object(py),
708+
err.value(py).clone().into_any(),
709+
);
710+
711+
assert!(normalized
712+
.context(py)
713+
.unwrap()
714+
.value(py)
715+
.is(looped_context.value(py)));
716+
});
717+
718+
// loop where err is part of the loop
719+
let err_a = PyRuntimeError::new_err("A");
720+
let err_b = PyRuntimeError::new_err("B");
721+
// a -> b -> a
722+
err_a.set_context(py, Some(err_b.clone_ref(py)));
723+
err_b.set_context(py, Some(err_a.clone_ref(py)));
724+
// handled = raised = a
725+
with_handled_exception(err_a.value(py), || {
726+
let (rust_normal, py_normal) = test_exception(
727+
&PyRuntimeError::type_object(py),
728+
err_a.value(py).clone().into_any(),
729+
);
730+
731+
// a.context -> b
732+
assert!(rust_normal
733+
.context(py)
734+
.unwrap()
735+
.value(py)
736+
.is(err_b.value(py)));
737+
assert!(py_normal.context(py).unwrap().value(py).is(err_b.value(py)));
738+
});
739+
740+
// no loop yet, but implicit context will loop if we set a.context = b
741+
let err_a = PyRuntimeError::new_err("A");
742+
let err_b = PyRuntimeError::new_err("B");
743+
err_b.set_context(py, Some(err_a.clone_ref(py)));
744+
// raised = a, handled = b
745+
with_handled_exception(err_b.value(py), || {
746+
test_exception(
747+
&PyRuntimeError::type_object(py),
748+
err_b.value(py).clone().into_any(),
749+
);
750+
});
751+
})
752+
}
753+
754+
#[cfg(Py_3_12)]
755+
fn with_handled_exception(exc: &Bound<'_, PyBaseException>, f: impl FnOnce()) {
756+
struct Guard;
757+
impl Drop for Guard {
758+
fn drop(&mut self) {
759+
unsafe { ffi::PyErr_SetHandledException(std::ptr::null_mut()) };
760+
}
761+
}
762+
763+
let guard = Guard;
764+
unsafe { ffi::PyErr_SetHandledException(exc.as_ptr()) };
765+
f();
766+
drop(guard);
767+
}
768+
769+
#[cfg(Py_3_12)]
770+
fn create_loop(py: Python<'_>, size: usize) -> PyErr {
771+
let first = PyValueError::new_err("exc0");
772+
let last = (1..size).fold(first.clone_ref(py), |prev, i| {
773+
let exc = PyValueError::new_err(format!("exc{i}"));
774+
prev.set_context(py, Some(exc.clone_ref(py)));
775+
exc
776+
});
777+
last.set_context(py, Some(first.clone_ref(py)));
778+
779+
first
780+
}
781+
782+
#[test]
783+
#[cfg(Py_3_12)]
784+
fn test_context_chain_iter_terminates() {
785+
Python::attach(|py| {
786+
for size in 1..=8 {
787+
let chain = create_loop(py, size);
788+
let count = super::context_chain_iter(chain.into_value(py).into_bound(py)).count();
789+
assert!(
790+
count >= size,
791+
"We should have seen each element at least once"
792+
);
793+
}
617794
})
618795
}
619796
}

0 commit comments

Comments
 (0)