@@ -6,18 +6,15 @@ use std::{
66
77#[ cfg( not( Py_3_12 ) ) ]
88use crate :: sync:: MutexExt ;
9+ #[ cfg( Py_3_12 ) ]
10+ use crate :: types:: { PyString , PyTuple } ;
911use crate :: {
1012 exceptions:: { PyBaseException , PyTypeError } ,
1113 ffi,
1214 ffi_ptr_ext:: FfiPtrExt ,
1315 types:: { PyAnyMethods , PyTraceback , PyType } ,
1416 Bound , Py , PyAny , PyErrArguments , PyTypeInfo , Python ,
1517} ;
16- #[ cfg( Py_3_12 ) ]
17- use {
18- crate :: types:: { PyString , PyTuple } ,
19- std:: ptr:: NonNull ,
20- } ;
2118
2219pub ( crate ) struct PyErrState {
2320 // Safety: can only hand out references when in the "normalized" state. Will never change
@@ -427,9 +424,36 @@ fn create_normalized_exception<'py>(
427424 unsafe { ptype. cast_unchecked ( ) } . clone ( )
428425 } ;
429426
427+ let mut current_handled_exception: Option < Bound < ' _ , PyBaseException > > = unsafe {
428+ ffi:: PyErr_GetHandledException ( )
429+ . assume_owned_or_opt ( py)
430+ . map ( |obj| obj. cast_into_unchecked ( ) )
431+ } ;
432+
430433 let pvalue = if pvalue. is_exact_instance ( & ptype) {
431434 // Safety: already an exception value of the correct type
432- Ok ( unsafe { pvalue. cast_into_unchecked :: < PyBaseException > ( ) } )
435+ let exc = unsafe { pvalue. cast_into_unchecked :: < PyBaseException > ( ) } ;
436+
437+ if current_handled_exception
438+ . as_ref ( )
439+ . map ( |current| current. is ( & exc) )
440+ . unwrap_or_default ( )
441+ {
442+ // Current exception is the same as it's context so do not set the context to avoid a loop
443+ current_handled_exception = None ;
444+ } else if let Some ( current_context) = current_handled_exception. as_ref ( ) {
445+ // Check if this exception is already in the context chain, so we do not create reference cycles in the context chain.
446+ let mut iter = context_chain_iter ( current_context. clone ( ) ) . peekable ( ) ;
447+ while let Some ( ( current, next) ) = iter. next ( ) . zip ( iter. peek ( ) ) {
448+ if next. is ( & exc) {
449+ // Loop in context chain, breaking the loop by not pointing to exc
450+ unsafe { ffi:: PyException_SetContext ( current. as_ptr ( ) , std:: ptr:: null_mut ( ) ) } ;
451+ break ;
452+ }
453+ }
454+ }
455+
456+ Ok ( exc)
433457 } else if pvalue. is_none ( ) {
434458 // None -> no arguments
435459 ptype. call0 ( ) . and_then ( |pvalue| Ok ( pvalue. cast_into ( ) ?) )
@@ -445,19 +469,67 @@ fn create_normalized_exception<'py>(
445469
446470 match pvalue {
447471 Ok ( pvalue) => {
448- unsafe {
449- if let Some ( context) = NonNull :: new ( ffi:: PyErr_GetHandledException ( ) ) {
450- ffi:: PyException_SetContext ( pvalue. as_ptr ( ) , context. as_ptr ( ) )
451- }
452- } ;
472+ // Implicitly set the context of the new exception to the currently handled exception, if any.
473+ if let Some ( context) = current_handled_exception {
474+ unsafe { ffi:: PyException_SetContext ( pvalue. as_ptr ( ) , context. into_ptr ( ) ) } ;
475+ }
453476 pvalue
454477 }
455478 Err ( e) => e. value ( py) . clone ( ) ,
456479 }
457480}
458481
482+ /// Iterates through the context chain of exceptions, starting from `start`, and yields each exception in the chain.
483+ /// When there is a loop in the chain it may yield some elements multiple times, but it will always terminate.
484+ #[ inline]
485+ #[ cfg( Py_3_12 ) ]
486+ fn context_chain_iter (
487+ start : Bound < ' _ , PyBaseException > ,
488+ ) -> impl Iterator < Item = Bound < ' _ , PyBaseException > > {
489+ #[ inline]
490+ fn get_next < ' py > ( current : & Bound < ' py , PyBaseException > ) -> Option < Bound < ' py , PyBaseException > > {
491+ unsafe {
492+ ffi:: PyException_GetContext ( current. as_ptr ( ) )
493+ . assume_owned_or_opt ( current. py ( ) )
494+ . map ( |obj| obj. cast_into_unchecked ( ) )
495+ }
496+ }
497+
498+ let mut slow = None ;
499+ let mut current = Some ( start) ;
500+ let mut slow_update_toggle = false ;
501+
502+ std:: iter:: from_fn ( move || {
503+ let next = get_next ( current. as_ref ( ) ?) ;
504+
505+ // Detect loops in the context chain using Floyd's Tortoise and Hare algorithm.
506+ if let Some ( ( current_slow, current_fast) ) = slow. as_ref ( ) . zip ( next. as_ref ( ) ) {
507+ if current_fast. is ( current_slow) {
508+ // Loop detected
509+ return current. take ( ) ;
510+ }
511+
512+ // Every second iteration, advance the slow pointer by one step
513+ if slow_update_toggle {
514+ slow = get_next ( current_slow) ;
515+ }
516+
517+ slow_update_toggle = !slow_update_toggle;
518+ }
519+
520+ // Set the slow pointer after the first iteration
521+ if slow. is_none ( ) {
522+ slow = current. clone ( )
523+ }
524+
525+ std:: mem:: replace ( & mut current, next)
526+ } )
527+ }
528+
459529#[ cfg( test) ]
460530mod tests {
531+ #[ cfg( Py_3_12 ) ]
532+ use crate :: { exceptions:: PyBaseException , ffi, Bound } ;
461533 use crate :: {
462534 exceptions:: PyValueError , sync:: PyOnceLock , Py , PyAny , PyErr , PyErrArguments , Python ,
463535 } ;
@@ -571,7 +643,10 @@ mod tests {
571643 Bound ,
572644 } ;
573645
574- fn test_exception < ' py > ( ptype : & Bound < ' py , PyAny > , pvalue : Bound < ' py , PyAny > ) {
646+ fn test_exception < ' py > (
647+ ptype : & Bound < ' py , PyAny > ,
648+ pvalue : Bound < ' py , PyAny > ,
649+ ) -> ( PyErr , PyErr ) {
575650 let py = ptype. py ( ) ;
576651
577652 let exc1 = super :: create_normalized_exception ( ptype, pvalue. clone ( ) ) ;
@@ -592,6 +667,15 @@ mod tests {
592667 assert ! ( err1. traceback( py) . xor( err2. traceback( py) ) . is_none( ) ) ;
593668 assert ! ( err1. cause( py) . xor( err2. cause( py) ) . is_none( ) ) ;
594669 assert_eq ! ( err1. to_string( ) , err2. to_string( ) ) ;
670+
671+ super :: context_chain_iter ( err1. value ( py) . clone ( ) )
672+ . zip ( super :: context_chain_iter ( err2. value ( py) . clone ( ) ) )
673+ . for_each ( |( context1, context2) | {
674+ assert ! ( context1. get_type( ) . is( context2. get_type( ) ) ) ;
675+ assert_eq ! ( context1. to_string( ) , context2. to_string( ) ) ;
676+ } ) ;
677+
678+ ( err1, err2)
595679 }
596680
597681 Python :: attach ( |py| {
@@ -614,6 +698,99 @@ mod tests {
614698 . into_any ( )
615699 . into_bound ( py) ,
616700 ) ;
701+
702+ // Loop where err is not part of the loop
703+ let looped_context = create_loop ( py, 3 ) ;
704+ let err = PyRuntimeError :: new_err ( "Boom" ) ;
705+ with_handled_exception ( looped_context. value ( py) , || {
706+ let ( normalized, _) = test_exception (
707+ & PyRuntimeError :: type_object ( py) ,
708+ err. value ( py) . clone ( ) . into_any ( ) ,
709+ ) ;
710+
711+ assert ! ( normalized
712+ . context( py)
713+ . unwrap( )
714+ . value( py)
715+ . is( looped_context. value( py) ) ) ;
716+ } ) ;
717+
718+ // loop where err is part of the loop
719+ let err_a = PyRuntimeError :: new_err ( "A" ) ;
720+ let err_b = PyRuntimeError :: new_err ( "B" ) ;
721+ // a -> b -> a
722+ err_a. set_context ( py, Some ( err_b. clone_ref ( py) ) ) ;
723+ err_b. set_context ( py, Some ( err_a. clone_ref ( py) ) ) ;
724+ // handled = raised = a
725+ with_handled_exception ( err_a. value ( py) , || {
726+ let ( rust_normal, py_normal) = test_exception (
727+ & PyRuntimeError :: type_object ( py) ,
728+ err_a. value ( py) . clone ( ) . into_any ( ) ,
729+ ) ;
730+
731+ // a.context -> b
732+ assert ! ( rust_normal
733+ . context( py)
734+ . unwrap( )
735+ . value( py)
736+ . is( err_b. value( py) ) ) ;
737+ assert ! ( py_normal. context( py) . unwrap( ) . value( py) . is( err_b. value( py) ) ) ;
738+ } ) ;
739+
740+ // no loop yet, but implicit context will loop if we set a.context = b
741+ let err_a = PyRuntimeError :: new_err ( "A" ) ;
742+ let err_b = PyRuntimeError :: new_err ( "B" ) ;
743+ err_b. set_context ( py, Some ( err_a. clone_ref ( py) ) ) ;
744+ // raised = a, handled = b
745+ with_handled_exception ( err_b. value ( py) , || {
746+ test_exception (
747+ & PyRuntimeError :: type_object ( py) ,
748+ err_b. value ( py) . clone ( ) . into_any ( ) ,
749+ ) ;
750+ } ) ;
751+ } )
752+ }
753+
754+ #[ cfg( Py_3_12 ) ]
755+ fn with_handled_exception ( exc : & Bound < ' _ , PyBaseException > , f : impl FnOnce ( ) ) {
756+ struct Guard ;
757+ impl Drop for Guard {
758+ fn drop ( & mut self ) {
759+ unsafe { ffi:: PyErr_SetHandledException ( std:: ptr:: null_mut ( ) ) } ;
760+ }
761+ }
762+
763+ let guard = Guard ;
764+ unsafe { ffi:: PyErr_SetHandledException ( exc. as_ptr ( ) ) } ;
765+ f ( ) ;
766+ drop ( guard) ;
767+ }
768+
769+ #[ cfg( Py_3_12 ) ]
770+ fn create_loop ( py : Python < ' _ > , size : usize ) -> PyErr {
771+ let first = PyValueError :: new_err ( "exc0" ) ;
772+ let last = ( 1 ..size) . fold ( first. clone_ref ( py) , |prev, i| {
773+ let exc = PyValueError :: new_err ( format ! ( "exc{i}" ) ) ;
774+ prev. set_context ( py, Some ( exc. clone_ref ( py) ) ) ;
775+ exc
776+ } ) ;
777+ last. set_context ( py, Some ( first. clone_ref ( py) ) ) ;
778+
779+ first
780+ }
781+
782+ #[ test]
783+ #[ cfg( Py_3_12 ) ]
784+ fn test_context_chain_iter_terminates ( ) {
785+ Python :: attach ( |py| {
786+ for size in 1 ..=8 {
787+ let chain = create_loop ( py, size) ;
788+ let count = super :: context_chain_iter ( chain. into_value ( py) . into_bound ( py) ) . count ( ) ;
789+ assert ! (
790+ count >= size,
791+ "We should have seen each element at least once"
792+ ) ;
793+ }
617794 } )
618795 }
619796}
0 commit comments