@@ -16,7 +16,7 @@ use crate::{
1616#[ cfg( Py_3_12 ) ]
1717use {
1818 crate :: types:: { PyString , PyTuple } ,
19- std:: ptr :: NonNull ,
19+ std:: ops :: ControlFlow :: { Break , Continue } ,
2020} ;
2121
2222pub ( crate ) struct PyErrState {
@@ -427,9 +427,36 @@ fn create_normalized_exception<'py>(
427427 unsafe { ptype. cast_unchecked ( ) } . clone ( )
428428 } ;
429429
430+ let mut current_handled_exception: Option < Bound < ' _ , PyBaseException > > = unsafe {
431+ ffi:: PyErr_GetHandledException ( )
432+ . assume_owned_or_opt ( py)
433+ . map ( |obj| obj. cast_into_unchecked ( ) )
434+ } ;
435+
430436 let pvalue = if pvalue. is_exact_instance ( & ptype) {
431437 // Safety: already an exception value of the correct type
432- Ok ( unsafe { pvalue. cast_into_unchecked :: < PyBaseException > ( ) } )
438+ let exc = unsafe { pvalue. cast_into_unchecked :: < PyBaseException > ( ) } ;
439+
440+ if current_handled_exception
441+ . as_ref ( )
442+ . map ( |current| current. is ( & exc) )
443+ . unwrap_or_default ( )
444+ {
445+ // Current exception is the same as it's context so do not set the context to avoid a loop
446+ current_handled_exception = None ;
447+ } else if let Some ( current_context) = current_handled_exception. as_ref ( ) {
448+ // Check if this exception is already in the context chain, so we do not create reference cycles in the context chain.
449+ let mut iter = context_chain_iter ( current_context. clone ( ) ) . peekable ( ) ;
450+ while let Some ( ( current, next) ) = iter. next ( ) . zip ( iter. peek ( ) ) {
451+ if next. is ( & exc) {
452+ // Loop in context chain, breaking the loop by not pointing to exc
453+ unsafe { ffi:: PyException_SetContext ( current. as_ptr ( ) , std:: ptr:: null_mut ( ) ) } ;
454+ break ;
455+ }
456+ }
457+ }
458+
459+ Ok ( exc)
433460 } else if pvalue. is_none ( ) {
434461 // None -> no arguments
435462 ptype. call0 ( ) . and_then ( |pvalue| Ok ( pvalue. cast_into ( ) ?) )
@@ -445,19 +472,67 @@ fn create_normalized_exception<'py>(
445472
446473 match pvalue {
447474 Ok ( pvalue) => {
448- unsafe {
449- if let Some ( context) = NonNull :: new ( ffi:: PyErr_GetHandledException ( ) ) {
450- ffi:: PyException_SetContext ( pvalue. as_ptr ( ) , context. as_ptr ( ) )
451- }
452- } ;
475+ // Implicitly set the context of the new exception to the currently handled exception, if any.
476+ if let Some ( context) = current_handled_exception {
477+ unsafe { ffi:: PyException_SetContext ( pvalue. as_ptr ( ) , context. into_ptr ( ) ) } ;
478+ }
453479 pvalue
454480 }
455481 Err ( e) => e. value ( py) . clone ( ) ,
456482 }
457483}
458484
485+ /// Iterates through the context chain of exceptions, starting from `start`, and yields each exception in the chain.
486+ /// When there is a loop in the chain it may yield some elements multiple times, but it will always terminate.
487+ #[ inline]
488+ #[ cfg( Py_3_12 ) ]
489+ fn context_chain_iter (
490+ start : Bound < ' _ , PyBaseException > ,
491+ ) -> impl Iterator < Item = Bound < ' _ , PyBaseException > > {
492+ #[ inline]
493+ fn get_next < ' py > ( current : & Bound < ' py , PyBaseException > ) -> Option < Bound < ' py , PyBaseException > > {
494+ unsafe {
495+ ffi:: PyException_GetContext ( current. as_ptr ( ) )
496+ . assume_owned_or_opt ( current. py ( ) )
497+ . map ( |obj| obj. cast_into_unchecked ( ) )
498+ }
499+ }
500+
501+ let mut slow = None ;
502+ let mut current = Some ( start) ;
503+ let mut slow_update_toggle = false ;
504+
505+ std:: iter:: from_fn ( move || {
506+ let next = get_next ( current. as_ref ( ) ?) ;
507+
508+ // Detect loops in the context chain using Floyd's Tortoise and Hare algorithm.
509+ if let Some ( ( current_slow, current_fast) ) = slow. as_ref ( ) . zip ( next. as_ref ( ) ) {
510+ if current_fast. is ( current_slow) {
511+ // Loop detected
512+ return current. take ( ) ;
513+ }
514+
515+ // Every second iteration, advance the slow pointer by one step
516+ if slow_update_toggle {
517+ slow = get_next ( current_slow) ;
518+ }
519+
520+ slow_update_toggle = !slow_update_toggle;
521+ }
522+
523+ // Set the slow pointer after the first iteration
524+ if slow. is_none ( ) {
525+ slow = current. clone ( )
526+ }
527+
528+ std:: mem:: replace ( & mut current, next)
529+ } )
530+ }
531+
459532#[ cfg( test) ]
460533mod tests {
534+ #[ cfg( Py_3_12 ) ]
535+ use crate :: { exceptions:: PyBaseException , ffi, Bound } ;
461536 use crate :: {
462537 exceptions:: PyValueError , sync:: PyOnceLock , Py , PyAny , PyErr , PyErrArguments , Python ,
463538 } ;
@@ -571,7 +646,10 @@ mod tests {
571646 Bound ,
572647 } ;
573648
574- fn test_exception < ' py > ( ptype : & Bound < ' py , PyAny > , pvalue : Bound < ' py , PyAny > ) {
649+ fn test_exception < ' py > (
650+ ptype : & Bound < ' py , PyAny > ,
651+ pvalue : Bound < ' py , PyAny > ,
652+ ) -> ( PyErr , PyErr ) {
575653 let py = ptype. py ( ) ;
576654
577655 let exc1 = super :: create_normalized_exception ( ptype, pvalue. clone ( ) ) ;
@@ -592,6 +670,15 @@ mod tests {
592670 assert ! ( err1. traceback( py) . xor( err2. traceback( py) ) . is_none( ) ) ;
593671 assert ! ( err1. cause( py) . xor( err2. cause( py) ) . is_none( ) ) ;
594672 assert_eq ! ( err1. to_string( ) , err2. to_string( ) ) ;
673+
674+ super :: context_chain_iter ( err1. value ( py) . clone ( ) )
675+ . zip ( super :: context_chain_iter ( err2. value ( py) . clone ( ) ) )
676+ . for_each ( |( context1, context2) | {
677+ assert ! ( context1. get_type( ) . is( context2. get_type( ) ) ) ;
678+ assert_eq ! ( context1. to_string( ) , context2. to_string( ) ) ;
679+ } ) ;
680+
681+ ( err1, err2)
595682 }
596683
597684 Python :: attach ( |py| {
@@ -614,6 +701,99 @@ mod tests {
614701 . into_any ( )
615702 . into_bound ( py) ,
616703 ) ;
704+
705+ // Loop where err is not part of the loop
706+ let looped_context = create_loop ( py, 3 ) ;
707+ let err = PyRuntimeError :: new_err ( "Boom" ) ;
708+ with_handled_exception ( looped_context. value ( py) , || {
709+ let ( normalized, _) = test_exception (
710+ & PyRuntimeError :: type_object ( py) ,
711+ err. value ( py) . clone ( ) . into_any ( ) ,
712+ ) ;
713+
714+ assert ! ( normalized
715+ . context( py)
716+ . unwrap( )
717+ . value( py)
718+ . is( looped_context. value( py) ) ) ;
719+ } ) ;
720+
721+ // loop where err is part of the loop
722+ let err_a = PyRuntimeError :: new_err ( "A" ) ;
723+ let err_b = PyRuntimeError :: new_err ( "B" ) ;
724+ // a -> b -> a
725+ err_a. set_context ( py, Some ( err_b. clone_ref ( py) ) ) ;
726+ err_b. set_context ( py, Some ( err_a. clone_ref ( py) ) ) ;
727+ // handled = raised = a
728+ with_handled_exception ( err_a. value ( py) , || {
729+ let ( rust_normal, py_normal) = test_exception (
730+ & PyRuntimeError :: type_object ( py) ,
731+ err_a. value ( py) . clone ( ) . into_any ( ) ,
732+ ) ;
733+
734+ // a.context -> b
735+ assert ! ( rust_normal
736+ . context( py)
737+ . unwrap( )
738+ . value( py)
739+ . is( err_b. value( py) ) ) ;
740+ assert ! ( py_normal. context( py) . unwrap( ) . value( py) . is( err_b. value( py) ) ) ;
741+ } ) ;
742+
743+ // no loop yet, but implicit context will loop if we set a.context = b
744+ let err_a = PyRuntimeError :: new_err ( "A" ) ;
745+ let err_b = PyRuntimeError :: new_err ( "B" ) ;
746+ err_b. set_context ( py, Some ( err_a. clone_ref ( py) ) ) ;
747+ // raised = a, handled = b
748+ with_handled_exception ( err_b. value ( py) , || {
749+ test_exception (
750+ & PyRuntimeError :: type_object ( py) ,
751+ err_b. value ( py) . clone ( ) . into_any ( ) ,
752+ ) ;
753+ } ) ;
754+ } )
755+ }
756+
757+ #[ cfg( Py_3_12 ) ]
758+ fn with_handled_exception ( exc : & Bound < ' _ , PyBaseException > , f : impl FnOnce ( ) ) {
759+ struct Guard ;
760+ impl Drop for Guard {
761+ fn drop ( & mut self ) {
762+ unsafe { ffi:: PyErr_SetHandledException ( std:: ptr:: null_mut ( ) ) } ;
763+ }
764+ }
765+
766+ let guard = Guard ;
767+ unsafe { ffi:: PyErr_SetHandledException ( exc. as_ptr ( ) ) } ;
768+ f ( ) ;
769+ drop ( guard) ;
770+ }
771+
772+ #[ cfg( Py_3_12 ) ]
773+ fn create_loop ( py : Python < ' _ > , size : usize ) -> PyErr {
774+ let first = PyValueError :: new_err ( "exc0" ) ;
775+ let last = ( 1 ..size) . fold ( first. clone_ref ( py) , |prev, i| {
776+ let exc = PyValueError :: new_err ( format ! ( "exc{i}" ) ) ;
777+ prev. set_context ( py, Some ( exc. clone_ref ( py) ) ) ;
778+ exc
779+ } ) ;
780+ last. set_context ( py, Some ( first. clone_ref ( py) ) ) ;
781+
782+ first
783+ }
784+
785+ #[ test]
786+ #[ cfg( Py_3_12 ) ]
787+ fn test_context_chain_iter_terminates ( ) {
788+ Python :: attach ( |py| {
789+ for size in 1 ..=8 {
790+ let chain = create_loop ( py, size) ;
791+ let count = super :: context_chain_iter ( chain. into_value ( py) . into_bound ( py) ) . count ( ) ;
792+ assert ! (
793+ count >= size,
794+ "We should have seen each element at least once"
795+ ) ;
796+ }
617797 } )
618798 }
619799}
0 commit comments