@@ -6,6 +6,8 @@ use std::{
66
77#[ cfg( not( Py_3_12 ) ) ]
88use crate :: sync:: MutexExt ;
9+ #[ cfg( Py_3_12 ) ]
10+ use crate :: types:: { PyString , PyTuple } ;
911use crate :: {
1012 exceptions:: { PyBaseException , PyTypeError } ,
1113 ffi,
@@ -383,26 +385,72 @@ fn lazy_into_normalized_ffi_tuple(
383385}
384386
385387/// Raises a "lazy" exception state into the Python interpreter.
386- ///
387- /// In principle this could be split in two; first a function to create an exception
388- /// in a normalized state, and then a call to `PyErr_SetRaisedException` to raise it.
389- ///
390- /// This would require either moving some logic from C to Rust, or requesting a new
391- /// API in CPython.
392388fn raise_lazy ( py : Python < ' _ > , lazy : Box < PyErrStateLazyFn > ) {
393389 let PyErrStateLazyFnOutput { ptype, pvalue } = lazy ( py) ;
390+
394391 unsafe {
392+ #[ cfg( not( Py_3_12 ) ) ]
395393 if ffi:: PyExceptionClass_Check ( ptype. as_ptr ( ) ) == 0 {
396394 ffi:: PyErr_SetString (
397395 PyTypeError :: type_object_raw ( py) . cast ( ) ,
398396 c"exceptions must derive from BaseException" . as_ptr ( ) ,
399- )
397+ ) ;
400398 } else {
401- ffi:: PyErr_SetObject ( ptype. as_ptr ( ) , pvalue. as_ptr ( ) )
399+ ffi:: PyErr_SetObject ( ptype. as_ptr ( ) , pvalue. as_ptr ( ) ) ;
400+ }
401+
402+ #[ cfg( Py_3_12 ) ]
403+ {
404+ let exc = create_normalized_exception ( ptype. bind ( py) , pvalue. into_bound ( py) ) ;
405+
406+ ffi:: PyErr_SetRaisedException ( exc. into_ptr ( ) ) ;
402407 }
403408 }
404409}
405410
411+ #[ cfg( Py_3_12 ) ]
412+ fn create_normalized_exception < ' py > (
413+ ptype : & Bound < ' py , PyAny > ,
414+ mut pvalue : Bound < ' py , PyAny > ,
415+ ) -> Bound < ' py , PyBaseException > {
416+ let py = ptype. py ( ) ;
417+
418+ // 1: check type is a subclass of BaseException
419+ let ptype: Bound < ' py , PyType > = if unsafe { ffi:: PyExceptionClass_Check ( ptype. as_ptr ( ) ) } == 0 {
420+ pvalue = PyString :: new ( py, "exceptions must derive from BaseException" ) . into_any ( ) ;
421+ PyTypeError :: type_object ( py)
422+ } else {
423+ // Safety: PyExceptionClass_Check guarantees that ptype is a subclass of BaseException
424+ unsafe { ptype. cast_unchecked ( ) } . clone ( )
425+ } ;
426+
427+ let pvalue = if pvalue. is_exact_instance ( & ptype) {
428+ // Safety: already an exception value of the correct type
429+ Ok ( unsafe { pvalue. cast_into_unchecked :: < PyBaseException > ( ) } )
430+ } else if pvalue. is_none ( ) {
431+ // None -> no arguments
432+ ptype. call0 ( ) . and_then ( |pvalue| Ok ( pvalue. cast_into ( ) ?) )
433+ } else if let Ok ( tup) = pvalue. cast :: < PyTuple > ( ) {
434+ // Tuple -> use as tuple of arguments
435+ ptype. call1 ( tup) . and_then ( |pvalue| Ok ( pvalue. cast_into ( ) ?) )
436+ } else {
437+ // Anything else -> use as single argument
438+ ptype
439+ . call1 ( ( pvalue, ) )
440+ . and_then ( |pvalue| Ok ( pvalue. cast_into ( ) ?) )
441+ } ;
442+
443+ match pvalue {
444+ Ok ( pvalue) => {
445+ unsafe {
446+ ffi:: PyException_SetContext ( pvalue. as_ptr ( ) , ffi:: PyErr_GetHandledException ( ) )
447+ } ;
448+ pvalue
449+ }
450+ Err ( e) => e. value ( py) . clone ( ) ,
451+ }
452+ }
453+
406454#[ cfg( test) ]
407455mod tests {
408456
@@ -478,4 +526,35 @@ mod tests {
478526 . is_instance_of:: <PyValueError >( py) )
479527 } ) ;
480528 }
529+
530+ #[ test]
531+ #[ cfg( feature = "macros" ) ]
532+ fn test_new_exception_context ( ) {
533+ use crate :: {
534+ exceptions:: { PyRuntimeError , PyValueError } ,
535+ pyfunction,
536+ types:: { PyDict , PyDictMethods } ,
537+ wrap_pyfunction, PyResult ,
538+ } ;
539+ #[ pyfunction( crate = "crate" ) ]
540+ fn throw_exception ( ) -> PyResult < ( ) > {
541+ Err ( PyValueError :: new_err ( "error happened" ) )
542+ }
543+
544+ Python :: attach ( |py| {
545+ let globals = PyDict :: new ( py) ;
546+ let f = wrap_pyfunction ! ( throw_exception, py) . unwrap ( ) ;
547+ globals. set_item ( "throw_exception" , f) . unwrap ( ) ;
548+ let err = py
549+ . run (
550+ c"try:\n raise RuntimeError()\n except:\n throw_exception()\n " ,
551+ Some ( & globals) ,
552+ None ,
553+ )
554+ . unwrap_err ( ) ;
555+
556+ let context = err. context ( py) . unwrap ( ) ;
557+ assert ! ( context. is_instance_of:: <PyRuntimeError >( py) )
558+ } )
559+ }
481560}
0 commit comments