diff --git a/newsfragments/5955.added.md b/newsfragments/5955.added.md new file mode 100644 index 00000000000..39af36e4c17 --- /dev/null +++ b/newsfragments/5955.added.md @@ -0,0 +1,6 @@ +* Added FFI bindings for `PyDict_SetDefaultRef` for Python 3.13 and newer and + 3.15 and newer in the limited API. Also added a compat shim for older Python + versions. + +* Added `PyDictMethods::set_default` and `PyDictMethods::set_default_with_result` to + allow atomically setting default values in a PyDict. diff --git a/newsfragments/5955.changed.md b/newsfragments/5955.changed.md new file mode 100644 index 00000000000..acb04bb2f4f --- /dev/null +++ b/newsfragments/5955.changed.md @@ -0,0 +1,2 @@ +`PyCode::run` uses `PyDict_SetDefaultRef` instead of a critical section to ensure +thread safety when it adds a reference to `__builtins__` to the globals dict. diff --git a/pyo3-ffi/src/compat/py_3_13.rs b/pyo3-ffi/src/compat/py_3_13.rs index 08bdf5cba18..26b9f2698fb 100644 --- a/pyo3-ffi/src/compat/py_3_13.rs +++ b/pyo3-ffi/src/compat/py_3_13.rs @@ -130,3 +130,50 @@ compat_function!( crate::_PyThreadState_UncheckedGet() } ); + +compat_function!( + originally_defined_for(all(Py_3_13, any(not(Py_LIMITED_API), Py_3_15))); + + #[inline] + pub unsafe fn PyDict_SetDefaultRef( + mp: *mut crate::PyObject, + key: *mut crate::PyObject, + default_value: *mut crate::PyObject, + result: *mut *mut crate::PyObject, + ) -> std::ffi::c_int { + use crate::{ + compat::{PyDict_GetItemRef, Py_NewRef}, + PyDict_SetItem, PyObject, Py_DECREF, + }; + let mut value: *mut PyObject = std::ptr::null_mut(); + if PyDict_GetItemRef(mp, key, &mut value) < 0 { + // get error + if !result.is_null() { + *result = std::ptr::null_mut(); + } + return -1; + } + if !value.is_null() { + // present + if !result.is_null() { + *result = value; + } else { + Py_DECREF(value); + } + return 1; + } + + // missing, set the item + if PyDict_SetItem(mp, key, default_value) < 0 { + // set error + if !result.is_null() { + *result = 
std::ptr::null_mut(); + } + return -1; + } + if !result.is_null() { + *result = Py_NewRef(default_value); + } + 0 + } +); diff --git a/pyo3-ffi/src/cpython/dictobject.rs b/pyo3-ffi/src/cpython/dictobject.rs index 37991ee4ebe..4c93f068e2b 100644 --- a/pyo3-ffi/src/cpython/dictobject.rs +++ b/pyo3-ffi/src/cpython/dictobject.rs @@ -43,7 +43,6 @@ pub struct PyDictObject { // skipped private _PyDict_GetItemStringWithError // skipped PyDict_SetDefault -// skipped PyDict_SetDefaultRef // skipped PyDict_GET_SIZE // skipped PyDict_ContainsString diff --git a/pyo3-ffi/src/dictobject.rs b/pyo3-ffi/src/dictobject.rs index 4afa2ffb5e8..4df1265d4b8 100644 --- a/pyo3-ffi/src/dictobject.rs +++ b/pyo3-ffi/src/dictobject.rs @@ -78,6 +78,13 @@ extern_libpython! { key: *const c_char, result: *mut *mut PyObject, ) -> c_int; + #[cfg(all(Py_3_13, any(not(Py_LIMITED_API), Py_3_15)))] + pub fn PyDict_SetDefaultRef( + mp: *mut PyObject, + key: *mut PyObject, + default_value: *mut PyObject, + result: *mut *mut PyObject, + ) -> c_int; // skipped 3.10 / ex-non-limited PyObject_GenericGetDict } diff --git a/src/err/mod.rs b/src/err/mod.rs index 53f97fa1c45..43d006d2332 100644 --- a/src/err/mod.rs +++ b/src/err/mod.rs @@ -761,6 +761,19 @@ pub(crate) fn error_on_minusone(py: Python<'_>, result: T) -> } } +/// Returns Ok wrapping the result if the error code is not -1. 
+#[inline] +pub(crate) fn error_on_minusone_with_result( + py: Python<'_>, + result: T, +) -> PyResult { + if result != T::MINUS_ONE { + Ok(result) + } else { + Err(PyErr::fetch(py)) + } +} + pub(crate) trait SignedInteger: Eq { const MINUS_ONE: Self; } diff --git a/src/types/code.rs b/src/types/code.rs index 8d38a8fc826..d35ca6c545e 100644 --- a/src/types/code.rs +++ b/src/types/code.rs @@ -1,5 +1,5 @@ -use super::PyAnyMethods as _; use super::PyDict; +use super::{PyAnyMethods as _, PyDictMethods as _}; use crate::ffi_ptr_ext::FfiPtrExt; use crate::py_result_ext::PyResultExt; #[cfg(any(Py_LIMITED_API, PyPy))] @@ -8,7 +8,7 @@ use crate::sync::PyOnceLock; use crate::types::{PyType, PyTypeMethods}; #[cfg(any(Py_LIMITED_API, PyPy))] use crate::Py; -use crate::{ffi, Bound, PyAny, PyErr, PyResult, Python}; +use crate::{ffi, Bound, PyAny, PyResult, Python}; use std::ffi::CStr; /// Represents a Python code object. @@ -127,28 +127,8 @@ impl<'py> PyCodeMethods<'py> for Bound<'py, PyCode> { // - https://github.com/python/cpython/pull/24564 (the same fix in CPython 3.10) // - https://github.com/PyO3/pyo3/issues/3370 let builtins_s = crate::intern!(self.py(), "__builtins__"); - let has_builtins = globals.contains(builtins_s)?; - if !has_builtins { - crate::sync::critical_section::with_critical_section(globals, || { - // check if another thread set __builtins__ while this thread was blocked on the critical section - let has_builtins = globals.contains(builtins_s)?; - if !has_builtins { - // Inherit current builtins. - let builtins = unsafe { ffi::PyEval_GetBuiltins() }; - - // `PyDict_SetItem` doesn't take ownership of `builtins`, but `PyEval_GetBuiltins` - // seems to return a borrowed reference, so no leak here. 
- if unsafe { - ffi::PyDict_SetItem(globals.as_ptr(), builtins_s.as_ptr(), builtins) - } == -1 - { - return Err(PyErr::fetch(self.py())); - } - } - Ok(()) - })?; - } - + let builtins = unsafe { ffi::PyEval_GetBuiltins().assume_borrowed_unchecked(self.py()) }; + globals.set_default(builtins_s, builtins)?; unsafe { ffi::PyEval_EvalCode(self.as_ptr(), globals.as_ptr(), locals.as_ptr()) .assume_owned_or_err(self.py()) diff --git a/src/types/dict.rs b/src/types/dict.rs index 0f03a217f79..aa5518b4c49 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -207,6 +207,29 @@ pub trait PyDictMethods<'py>: crate::sealed::Sealed { /// This method uses [`PyDict_Merge`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Merge) internally, /// so should have the same performance as `update`. fn update_if_missing(&self, other: &Bound<'_, PyMapping>) -> PyResult<()>; + + /// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the + /// dictionary. If the key was inserted, returns Ok(true), otherwise returns Ok(false), indicating the key was + /// already present. If an error happens, returns PyErr. This function uses + /// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally. + fn set_default(&self, key: K, default_value: V) -> PyResult + where + K: IntoPyObject<'py>, + V: IntoPyObject<'py>; + + /// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the + /// dictionary. If the key was inserted, returns Ok((true, result)), otherwise returns Ok((false, result)) where + /// `result` is the `value` associated with `key` after this function finishes. If an error happens, returns + /// PyErr. This function uses + /// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally. 
+ fn set_default_with_result( + &self, + key: K, + default_value: V, + ) -> PyResult<(bool, Bound<'py, PyAny>)> + where + K: IntoPyObject<'py>, + V: IntoPyObject<'py>; } impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> { @@ -385,6 +408,92 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> { ffi::PyDict_Merge(self.as_ptr(), other.as_ptr(), 0) }) } + + fn set_default(&self, key: K, default_value: V) -> PyResult + where + K: IntoPyObject<'py>, + V: IntoPyObject<'py>, + { + fn inner( + dict: &Bound<'_, PyDict>, + key: Borrowed<'_, '_, PyAny>, + value: Borrowed<'_, '_, PyAny>, + ) -> PyResult { + setdefault_result_from_nonerror_return_code(err::error_on_minusone_with_result( + dict.py(), + unsafe { + ffi::compat::PyDict_SetDefaultRef( + dict.as_ptr(), + key.as_ptr(), + value.as_ptr(), + std::ptr::null_mut(), + ) + }, + )) + } + let py = self.py(); + + inner( + self, + key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(), + default_value + .into_pyobject_or_pyerr(py)? + .into_any() + .as_borrowed(), + ) + } + + fn set_default_with_result( + &self, + key: K, + default_value: V, + ) -> PyResult<(bool, Bound<'py, PyAny>)> + where + K: IntoPyObject<'py>, + V: IntoPyObject<'py>, + { + fn inner<'py>( + dict: &Bound<'_, PyDict>, + key: Borrowed<'_, '_, PyAny>, + value: Borrowed<'_, '_, PyAny>, + py: Python<'py>, + ) -> PyResult<(bool, Bound<'py, PyAny>)> { + let mut result = std::ptr::NonNull::dangling().as_ptr(); + let code = setdefault_result_from_nonerror_return_code( + err::error_on_minusone_with_result(dict.py(), unsafe { + ffi::compat::PyDict_SetDefaultRef( + dict.as_ptr(), + key.as_ptr(), + value.as_ptr(), + &mut result, + ) + }), + )?; + // SAFETY: the interpreter should have set this to a valid owned PyObject pointer + let out_result = unsafe { result.assume_owned_unchecked(py) }; + Ok((code, out_result)) + } + let py = self.py(); + inner( + self, + key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(), + default_value + .into_pyobject_or_pyerr(py)? 
+ .into_any() + .as_borrowed(), + py, + ) + } +} + +fn setdefault_result_from_nonerror_return_code(code: PyResult) -> PyResult { + match code? { + // inserted + 0 => Ok(true), + // not inserted + 1 => Ok(false), + x => panic!("Unknown return value from PyDict_SetDefaultRef: {x}"), + } } impl<'a, 'py> Borrowed<'a, 'py, PyDict> { @@ -1669,4 +1778,54 @@ mod tests { assert_eq!(dict.iter().count(), 3); }) } + + #[test] + fn test_set_default() { + Python::attach(|py| { + let dict = PyDict::new(py); + assert!(matches!(dict.set_default("hello", "world"), Ok(true))); + assert_eq!( + dict.get_item("hello") + .unwrap() + .unwrap() + .extract::() + .unwrap(), + "world" + ); + + assert!(matches!(dict.set_default("hello", "foobar"), Ok(false))); + + // unhashable + let invalid_key = PyList::new(py, vec![0]).unwrap(); + assert!(dict.set_default(invalid_key, "foobar").is_err()); + }) + } + + #[test] + fn test_set_default_with_result() { + Python::attach(|py| { + let dict = PyDict::new(py); + let res = dict.set_default_with_result("hello", "world"); + assert!(res.is_ok()); + let (inserted, value) = res.unwrap(); + assert!(inserted); + assert!(value.extract::().unwrap() == "world"); + assert!( + dict.get_item("hello") + .unwrap() + .unwrap() + .extract::() + .unwrap() + == "world" + ); + + let (inserted, value) = dict.set_default_with_result("hello", "foobar").unwrap(); + assert!(!inserted); + assert_eq!(value.extract::().unwrap(), "world"); + + // unhashable + let invalid_key = PyList::new(py, vec![0]).unwrap(); + assert!(dict.set_default_with_result(invalid_key, "foobar").is_err()); + }) + } }