Skip to content

Commit 83e523d

Browse files
Break reference cycles in context chain
1 parent 8194fe0 commit 83e523d

1 file changed

Lines changed: 188 additions & 8 deletions

File tree

src/err/err_state.rs

Lines changed: 188 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use crate::{
1616
#[cfg(Py_3_12)]
1717
use {
1818
crate::types::{PyString, PyTuple},
19-
std::ptr::NonNull,
19+
std::ops::ControlFlow::{Break, Continue},
2020
};
2121

2222
pub(crate) struct PyErrState {
@@ -427,9 +427,36 @@ fn create_normalized_exception<'py>(
427427
unsafe { ptype.cast_unchecked() }.clone()
428428
};
429429

430+
let mut current_handled_exception: Option<Bound<'_, PyBaseException>> = unsafe {
431+
ffi::PyErr_GetHandledException()
432+
.assume_owned_or_opt(py)
433+
.map(|obj| obj.cast_into_unchecked())
434+
};
435+
430436
let pvalue = if pvalue.is_exact_instance(&ptype) {
431437
// Safety: already an exception value of the correct type
432-
Ok(unsafe { pvalue.cast_into_unchecked::<PyBaseException>() })
438+
let exc = unsafe { pvalue.cast_into_unchecked::<PyBaseException>() };
439+
440+
if current_handled_exception
441+
.as_ref()
442+
.map(|current| current.is(&exc))
443+
.unwrap_or_default()
444+
{
445+
// Current exception is the same as it's context so do not set the context to avoid a loop
446+
current_handled_exception = None;
447+
} else if let Some(current_context) = current_handled_exception.as_ref() {
448+
// Check if this exception is already in the context chain, so we do not create reference cycles in the context chain.
449+
let mut iter = context_chain_iter(current_context.clone()).peekable();
450+
while let Some((current, next)) = iter.next().zip(iter.peek()) {
451+
if next.is(&exc) {
452+
// Loop in context chain, breaking the loop by not pointing to exc
453+
unsafe { ffi::PyException_SetContext(current.as_ptr(), std::ptr::null_mut()) };
454+
break;
455+
}
456+
}
457+
}
458+
459+
Ok(exc)
433460
} else if pvalue.is_none() {
434461
// None -> no arguments
435462
ptype.call0().and_then(|pvalue| Ok(pvalue.cast_into()?))
@@ -445,19 +472,67 @@ fn create_normalized_exception<'py>(
445472

446473
match pvalue {
447474
Ok(pvalue) => {
448-
unsafe {
449-
if let Some(context) = NonNull::new(ffi::PyErr_GetHandledException()) {
450-
ffi::PyException_SetContext(pvalue.as_ptr(), context.as_ptr())
451-
}
452-
};
475+
// Implicitly set the context of the new exception to the currently handled exception, if any.
476+
if let Some(context) = current_handled_exception {
477+
unsafe { ffi::PyException_SetContext(pvalue.as_ptr(), context.into_ptr()) };
478+
}
453479
pvalue
454480
}
455481
Err(e) => e.value(py).clone(),
456482
}
457483
}
458484

485+
/// Iterates through the context chain of exceptions, starting from `start`, and yields each exception in the chain.
486+
/// When there is a loop in the chain it may yield some elements multiple times, but it will always terminate.
487+
#[inline]
488+
#[cfg(Py_3_12)]
489+
fn context_chain_iter(
490+
start: Bound<'_, PyBaseException>,
491+
) -> impl Iterator<Item = Bound<'_, PyBaseException>> {
492+
#[inline]
493+
fn get_next<'py>(current: &Bound<'py, PyBaseException>) -> Option<Bound<'py, PyBaseException>> {
494+
unsafe {
495+
ffi::PyException_GetContext(current.as_ptr())
496+
.assume_owned_or_opt(current.py())
497+
.map(|obj| obj.cast_into_unchecked())
498+
}
499+
}
500+
501+
let mut slow = None;
502+
let mut current = Some(start);
503+
let mut slow_update_toggle = false;
504+
505+
std::iter::from_fn(move || {
506+
let next = get_next(current.as_ref()?);
507+
508+
// Detect loops in the context chain using Floyd's Tortoise and Hare algorithm.
509+
if let Some((current_slow, current_fast)) = slow.as_ref().zip(next.as_ref()) {
510+
if current_fast.is(current_slow) {
511+
// Loop detected
512+
return current.take();
513+
}
514+
515+
// Every second iteration, advance the slow pointer by one step
516+
if slow_update_toggle {
517+
slow = get_next(current_slow);
518+
}
519+
520+
slow_update_toggle = !slow_update_toggle;
521+
}
522+
523+
// Set the slow pointer after the first iteration
524+
if slow.is_none() {
525+
slow = current.clone()
526+
}
527+
528+
std::mem::replace(&mut current, next)
529+
})
530+
}
531+
459532
#[cfg(test)]
460533
mod tests {
534+
#[cfg(Py_3_12)]
535+
use crate::{exceptions::PyBaseException, ffi, Bound};
461536
use crate::{
462537
exceptions::PyValueError, sync::PyOnceLock, Py, PyAny, PyErr, PyErrArguments, Python,
463538
};
@@ -571,7 +646,10 @@ mod tests {
571646
Bound,
572647
};
573648

574-
fn test_exception<'py>(ptype: &Bound<'py, PyAny>, pvalue: Bound<'py, PyAny>) {
649+
fn test_exception<'py>(
650+
ptype: &Bound<'py, PyAny>,
651+
pvalue: Bound<'py, PyAny>,
652+
) -> (PyErr, PyErr) {
575653
let py = ptype.py();
576654

577655
let exc1 = super::create_normalized_exception(ptype, pvalue.clone());
@@ -592,6 +670,15 @@ mod tests {
592670
assert!(err1.traceback(py).xor(err2.traceback(py)).is_none());
593671
assert!(err1.cause(py).xor(err2.cause(py)).is_none());
594672
assert_eq!(err1.to_string(), err2.to_string());
673+
674+
super::context_chain_iter(err1.value(py).clone())
675+
.zip(super::context_chain_iter(err2.value(py).clone()))
676+
.for_each(|(context1, context2)| {
677+
assert!(context1.get_type().is(context2.get_type()));
678+
assert_eq!(context1.to_string(), context2.to_string());
679+
});
680+
681+
(err1, err2)
595682
}
596683

597684
Python::attach(|py| {
@@ -614,6 +701,99 @@ mod tests {
614701
.into_any()
615702
.into_bound(py),
616703
);
704+
705+
// Loop where err is not part of the loop
706+
let looped_context = create_loop(py, 3);
707+
let err = PyRuntimeError::new_err("Boom");
708+
with_handled_exception(looped_context.value(py), || {
709+
let (normalized, _) = test_exception(
710+
&PyRuntimeError::type_object(py),
711+
err.value(py).clone().into_any(),
712+
);
713+
714+
assert!(normalized
715+
.context(py)
716+
.unwrap()
717+
.value(py)
718+
.is(looped_context.value(py)));
719+
});
720+
721+
// loop where err is part of the loop
722+
let err_a = PyRuntimeError::new_err("A");
723+
let err_b = PyRuntimeError::new_err("B");
724+
// a -> b -> a
725+
err_a.set_context(py, Some(err_b.clone_ref(py)));
726+
err_b.set_context(py, Some(err_a.clone_ref(py)));
727+
// handled = raised = a
728+
with_handled_exception(err_a.value(py), || {
729+
let (rust_normal, py_normal) = test_exception(
730+
&PyRuntimeError::type_object(py),
731+
err_a.value(py).clone().into_any(),
732+
);
733+
734+
// a.context -> b
735+
assert!(rust_normal
736+
.context(py)
737+
.unwrap()
738+
.value(py)
739+
.is(err_b.value(py)));
740+
assert!(py_normal.context(py).unwrap().value(py).is(err_b.value(py)));
741+
});
742+
743+
// no loop yet, but implicit context will loop if we set a.context = b
744+
let err_a = PyRuntimeError::new_err("A");
745+
let err_b = PyRuntimeError::new_err("B");
746+
err_b.set_context(py, Some(err_a.clone_ref(py)));
747+
// raised = a, handled = b
748+
with_handled_exception(err_b.value(py), || {
749+
test_exception(
750+
&PyRuntimeError::type_object(py),
751+
err_b.value(py).clone().into_any(),
752+
);
753+
});
754+
})
755+
}
756+
757+
#[cfg(Py_3_12)]
758+
fn with_handled_exception(exc: &Bound<'_, PyBaseException>, f: impl FnOnce()) {
759+
struct Guard;
760+
impl Drop for Guard {
761+
fn drop(&mut self) {
762+
unsafe { ffi::PyErr_SetHandledException(std::ptr::null_mut()) };
763+
}
764+
}
765+
766+
let guard = Guard;
767+
unsafe { ffi::PyErr_SetHandledException(exc.as_ptr()) };
768+
f();
769+
drop(guard);
770+
}
771+
772+
#[cfg(Py_3_12)]
773+
fn create_loop(py: Python<'_>, size: usize) -> PyErr {
774+
let first = PyValueError::new_err("exc0");
775+
let last = (1..size).fold(first.clone_ref(py), |prev, i| {
776+
let exc = PyValueError::new_err(format!("exc{i}"));
777+
prev.set_context(py, Some(exc.clone_ref(py)));
778+
exc
779+
});
780+
last.set_context(py, Some(first.clone_ref(py)));
781+
782+
first
783+
}
784+
785+
#[test]
786+
#[cfg(Py_3_12)]
787+
fn test_context_chain_iter_terminates() {
788+
Python::attach(|py| {
789+
for size in 1..=8 {
790+
let chain = create_loop(py, size);
791+
let count = super::context_chain_iter(chain.into_value(py).into_bound(py)).count();
792+
assert!(
793+
count >= size,
794+
"We should have seen each element at least once"
795+
);
796+
}
617797
})
618798
}
619799
}

0 commit comments

Comments
 (0)