@@ -207,6 +207,29 @@ pub trait PyDictMethods<'py>: crate::sealed::Sealed {
207207 /// This method uses [`PyDict_Merge`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Merge) internally,
208208 /// so should have the same performance as `update`.
209209 fn update_if_missing ( & self , other : & Bound < ' _ , PyMapping > ) -> PyResult < ( ) > ;
210+
211+ /// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the
212+ /// dictionary. If the key was inserted, returns Ok(true), otherwise returns Ok(false), indicating the key was
213+ /// already present. If an error happens, returns PyErr. This function uses
214+ /// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally.
215+ fn set_default < K , V > ( & self , key : K , default_value : V ) -> PyResult < bool >
216+ where
217+ K : IntoPyObject < ' py > ,
218+ V : IntoPyObject < ' py > ;
219+
220+ /// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the
221+ /// dictionary. If the key was inserted, returns Ok((true, result)), otherwise returns Ok((false, result)) where
222+ /// `result` is the `value` associated with `key` after this function finishes. If an error happens, returns
223+ /// PyErr. This function uses
224+ /// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally.
225+ fn set_default_with_result < K , V > (
226+ & self ,
227+ key : K ,
228+ default_value : V ,
229+ ) -> PyResult < ( bool , Bound < ' py , PyAny > ) >
230+ where
231+ K : IntoPyObject < ' py > ,
232+ V : IntoPyObject < ' py > ;
210233}
211234
212235impl < ' py > PyDictMethods < ' py > for Bound < ' py , PyDict > {
@@ -385,6 +408,92 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
385408 ffi:: PyDict_Merge ( self . as_ptr ( ) , other. as_ptr ( ) , 0 )
386409 } )
387410 }
411+
412+ fn set_default < K , V > ( & self , key : K , default_value : V ) -> PyResult < bool >
413+ where
414+ K : IntoPyObject < ' py > ,
415+ V : IntoPyObject < ' py > ,
416+ {
417+ fn inner (
418+ dict : & Bound < ' _ , PyDict > ,
419+ key : Borrowed < ' _ , ' _ , PyAny > ,
420+ value : Borrowed < ' _ , ' _ , PyAny > ,
421+ ) -> PyResult < bool > {
422+ setdefault_result_from_nonerror_return_code ( err:: error_on_minusone_with_result (
423+ dict. py ( ) ,
424+ unsafe {
425+ ffi:: compat:: PyDict_SetDefaultRef (
426+ dict. as_ptr ( ) ,
427+ key. as_ptr ( ) ,
428+ value. as_ptr ( ) ,
429+ std:: ptr:: null_mut ( ) ,
430+ )
431+ } ,
432+ ) )
433+ }
434+ let py = self . py ( ) ;
435+
436+ inner (
437+ self ,
438+ key. into_pyobject_or_pyerr ( py) ?. into_any ( ) . as_borrowed ( ) ,
439+ default_value
440+ . into_pyobject_or_pyerr ( py) ?
441+ . into_any ( )
442+ . as_borrowed ( ) ,
443+ )
444+ }
445+
446+ fn set_default_with_result < K , V > (
447+ & self ,
448+ key : K ,
449+ default_value : V ,
450+ ) -> PyResult < ( bool , Bound < ' py , PyAny > ) >
451+ where
452+ K : IntoPyObject < ' py > ,
453+ V : IntoPyObject < ' py > ,
454+ {
455+ fn inner < ' py > (
456+ dict : & Bound < ' _ , PyDict > ,
457+ key : Borrowed < ' _ , ' _ , PyAny > ,
458+ value : Borrowed < ' _ , ' _ , PyAny > ,
459+ py : Python < ' py > ,
460+ ) -> PyResult < ( bool , Bound < ' py , PyAny > ) > {
461+ let mut result = std:: ptr:: NonNull :: dangling ( ) . as_ptr ( ) ;
462+ let code = setdefault_result_from_nonerror_return_code (
463+ err:: error_on_minusone_with_result ( dict. py ( ) , unsafe {
464+ ffi:: compat:: PyDict_SetDefaultRef (
465+ dict. as_ptr ( ) ,
466+ key. as_ptr ( ) ,
467+ value. as_ptr ( ) ,
468+ & mut result,
469+ )
470+ } ) ,
471+ ) ?;
472+ // SAFETY: the interpreter should have set this to a valid owned PyObject pointer
473+ let out_result = unsafe { result. assume_owned_unchecked ( py) } ;
474+ Ok ( ( code, out_result) )
475+ }
476+ let py = self . py ( ) ;
477+ inner (
478+ self ,
479+ key. into_pyobject_or_pyerr ( py) ?. into_any ( ) . as_borrowed ( ) ,
480+ default_value
481+ . into_pyobject_or_pyerr ( py) ?
482+ . into_any ( )
483+ . as_borrowed ( ) ,
484+ py,
485+ )
486+ }
487+ }
488+
489+ fn setdefault_result_from_nonerror_return_code ( code : PyResult < std:: ffi:: c_int > ) -> PyResult < bool > {
490+ match code? {
491+ // inserted
492+ 0 => Ok ( true ) ,
493+ // not inserted
494+ 1 => Ok ( false ) ,
495+ x => panic ! ( "Unknown return value from PyDict_SetDefaultRef: {x}" ) ,
496+ }
388497}
389498
390499impl < ' a , ' py > Borrowed < ' a , ' py , PyDict > {
@@ -1669,4 +1778,54 @@ mod tests {
16691778 assert_eq ! ( dict. iter( ) . count( ) , 3 ) ;
16701779 } )
16711780 }
1781+
1782+ #[ test]
1783+ fn test_set_default ( ) {
1784+ Python :: attach ( |py| {
1785+ let dict = PyDict :: new ( py) ;
1786+ assert ! ( matches!( dict. set_default( "hello" , "world" ) , Ok ( true ) ) ) ;
1787+ assert_eq ! (
1788+ dict. get_item( "hello" )
1789+ . unwrap( )
1790+ . unwrap( )
1791+ . extract:: <String >( )
1792+ . unwrap( ) ,
1793+ "world"
1794+ ) ;
1795+
1796+ assert ! ( matches!( dict. set_default( "hello" , "foobar" ) , Ok ( false ) ) ) ;
1797+
1798+ // unhashable
1799+ let invalid_key = PyList :: new ( py, vec ! [ 0 ] ) . unwrap ( ) ;
1800+ assert ! ( dict. set_default( invalid_key, "foobar" ) . is_err( ) ) ;
1801+ } )
1802+ }
1803+
1804+ #[ test]
1805+ fn test_set_default_with_result ( ) {
1806+ Python :: attach ( |py| {
1807+ let dict = PyDict :: new ( py) ;
1808+ let res = dict. set_default_with_result ( "hello" , "world" ) ;
1809+ assert ! ( res. is_ok( ) ) ;
1810+ let ( inserted, value) = res. unwrap ( ) ;
1811+ assert ! ( inserted) ;
1812+ assert ! ( value. extract:: <String >( ) . unwrap( ) == "world" ) ;
1813+ assert ! (
1814+ dict. get_item( "hello" )
1815+ . unwrap( )
1816+ . unwrap( )
1817+ . extract:: <String >( )
1818+ . unwrap( )
1819+ == "world"
1820+ ) ;
1821+
1822+ let ( inserted, value) = dict. set_default_with_result ( "hello" , "foobar" ) . unwrap ( ) ;
1823+ assert ! ( !inserted) ;
1824+ assert_eq ! ( value. extract:: <String >( ) . unwrap( ) , "world" ) ;
1825+
1826+ // unhashable
1827+ let invalid_key = PyList :: new ( py, vec ! [ 0 ] ) . unwrap ( ) ;
1828+ assert ! ( dict. set_default_with_result( invalid_key, "foobar" ) . is_err( ) ) ;
1829+ } )
1830+ }
16721831}
0 commit comments