Skip to content

Commit 8ddcf9c

Browse files
ngoldbaumdavidhewitt
authored andcommitted
Add bindings for PyDict_SetDefaultRef and use them in PyCode::run (PyO3#5955)
* Add bindings for PyDict_SetDefaultRef and use them in PyCode::run * fix conditional compilation * simplify conditional predicates * minor touchups * add changelog entries * pass NULL for the result * Add new PyDictMethods API wrapping PyDict_SetDefaultRef * gate new test for 3.9 and older limited API * fix clippy and typos * Apply code review comments from David * typo fix
1 parent d2f9e1e commit 8ddcf9c

8 files changed

Lines changed: 238 additions & 25 deletions

File tree

newsfragments/5955.added.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
* Added FFI bindings for `PyDict_SetDefaultRef` for Python 3.13 and newer and
2+
3.15 and newer in the limited API. Also added a compat shim for older Python
3+
versions.
4+
5+
* Added `PyDictMethods::set_default` and `PyDictMethods::set_default_ref` to
6+
allow atomically setting default values in a PyDict.

newsfragments/5955.changed.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
`PyCode::run` uses PyDict_SetDefaultRef instead of a critical section to ensure
2+
thread safety when it adds a reference to `__builtins__` to the globals dict.

pyo3-ffi/src/compat/py_3_13.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,50 @@ compat_function!(
130130
crate::_PyThreadState_UncheckedGet()
131131
}
132132
);
133+
134+
compat_function!(
135+
originally_defined_for(all(Py_3_13, any(not(Py_LIMITED_API), Py_3_15)));
136+
137+
#[inline]
138+
pub unsafe fn PyDict_SetDefaultRef(
139+
mp: *mut crate::PyObject,
140+
key: *mut crate::PyObject,
141+
default_value: *mut crate::PyObject,
142+
result: *mut *mut crate::PyObject,
143+
) -> std::ffi::c_int {
144+
use crate::{
145+
compat::{PyDict_GetItemRef, Py_NewRef},
146+
PyDict_SetItem, PyObject, Py_DECREF,
147+
};
148+
let mut value: *mut PyObject = std::ptr::null_mut();
149+
if PyDict_GetItemRef(mp, key, &mut value) < 0 {
150+
// get error
151+
if !result.is_null() {
152+
*result = std::ptr::null_mut();
153+
}
154+
return -1;
155+
}
156+
if !value.is_null() {
157+
// present
158+
if !result.is_null() {
159+
*result = value;
160+
} else {
161+
Py_DECREF(value);
162+
}
163+
return 1;
164+
}
165+
166+
// missing, set the item
167+
if PyDict_SetItem(mp, key, default_value) < 0 {
168+
// set error
169+
if !result.is_null() {
170+
*result = std::ptr::null_mut();
171+
}
172+
return -1;
173+
}
174+
if !result.is_null() {
175+
*result = Py_NewRef(default_value);
176+
}
177+
0
178+
}
179+
);

pyo3-ffi/src/cpython/dictobject.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ pub struct PyDictObject {
4343
// skipped private _PyDict_GetItemStringWithError
4444

4545
// skipped PyDict_SetDefault
46-
// skipped PyDict_SetDefaultRef
4746

4847
// skipped PyDict_GET_SIZE
4948
// skipped PyDict_ContainsString

pyo3-ffi/src/dictobject.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ extern_libpython! {
7878
key: *const c_char,
7979
result: *mut *mut PyObject,
8080
) -> c_int;
81+
#[cfg(all(Py_3_13, any(not(Py_LIMITED_API), Py_3_15)))]
82+
pub fn PyDict_SetDefaultRef(
83+
mp: *mut PyObject,
84+
key: *mut PyObject,
85+
default_value: *mut PyObject,
86+
result: *mut *mut PyObject,
87+
) -> c_int;
8188
// skipped 3.10 / ex-non-limited PyObject_GenericGetDict
8289
}
8390

src/err/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,19 @@ pub(crate) fn error_on_minusone<T: SignedInteger>(py: Python<'_>, result: T) ->
761761
}
762762
}
763763

764+
/// Returns Ok wrapping the result if the error code is not -1.
765+
#[inline]
766+
pub(crate) fn error_on_minusone_with_result<T: SignedInteger>(
767+
py: Python<'_>,
768+
result: T,
769+
) -> PyResult<T> {
770+
if result != T::MINUS_ONE {
771+
Ok(result)
772+
} else {
773+
Err(PyErr::fetch(py))
774+
}
775+
}
776+
764777
pub(crate) trait SignedInteger: Eq {
765778
const MINUS_ONE: Self;
766779
}

src/types/code.rs

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
use super::PyAnyMethods as _;
21
use super::PyDict;
2+
use super::{PyAnyMethods as _, PyDictMethods as _};
33
use crate::ffi_ptr_ext::FfiPtrExt;
44
use crate::py_result_ext::PyResultExt;
55
#[cfg(any(Py_LIMITED_API, PyPy))]
@@ -8,7 +8,7 @@ use crate::sync::PyOnceLock;
88
use crate::types::{PyType, PyTypeMethods};
99
#[cfg(any(Py_LIMITED_API, PyPy))]
1010
use crate::Py;
11-
use crate::{ffi, Bound, PyAny, PyErr, PyResult, Python};
11+
use crate::{ffi, Bound, PyAny, PyResult, Python};
1212
use std::ffi::CStr;
1313

1414
/// Represents a Python code object.
@@ -127,28 +127,8 @@ impl<'py> PyCodeMethods<'py> for Bound<'py, PyCode> {
127127
// - https://github.com/python/cpython/pull/24564 (the same fix in CPython 3.10)
128128
// - https://github.com/PyO3/pyo3/issues/3370
129129
let builtins_s = crate::intern!(self.py(), "__builtins__");
130-
let has_builtins = globals.contains(builtins_s)?;
131-
if !has_builtins {
132-
crate::sync::critical_section::with_critical_section(globals, || {
133-
// check if another thread set __builtins__ while this thread was blocked on the critical section
134-
let has_builtins = globals.contains(builtins_s)?;
135-
if !has_builtins {
136-
// Inherit current builtins.
137-
let builtins = unsafe { ffi::PyEval_GetBuiltins() };
138-
139-
// `PyDict_SetItem` doesn't take ownership of `builtins`, but `PyEval_GetBuiltins`
140-
// seems to return a borrowed reference, so no leak here.
141-
if unsafe {
142-
ffi::PyDict_SetItem(globals.as_ptr(), builtins_s.as_ptr(), builtins)
143-
} == -1
144-
{
145-
return Err(PyErr::fetch(self.py()));
146-
}
147-
}
148-
Ok(())
149-
})?;
150-
}
151-
130+
let builtins = unsafe { ffi::PyEval_GetBuiltins().assume_borrowed_unchecked(self.py()) };
131+
globals.set_default(builtins_s, builtins)?;
152132
unsafe {
153133
ffi::PyEval_EvalCode(self.as_ptr(), globals.as_ptr(), locals.as_ptr())
154134
.assume_owned_or_err(self.py())

src/types/dict.rs

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,29 @@ pub trait PyDictMethods<'py>: crate::sealed::Sealed {
207207
/// This method uses [`PyDict_Merge`](https://docs.python.org/3/c-api/dict.html#c.PyDict_Merge) internally,
208208
/// so should have the same performance as `update`.
209209
fn update_if_missing(&self, other: &Bound<'_, PyMapping>) -> PyResult<()>;
210+
211+
/// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the
212+
/// dictionary. If the key was inserted, returns Ok(true), otherwise returns Ok(false), indicating the key was
213+
/// already present. If an error happens, returns PyErr. This function uses
214+
/// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally.
215+
fn set_default<K, V>(&self, key: K, default_value: V) -> PyResult<bool>
216+
where
217+
K: IntoPyObject<'py>,
218+
V: IntoPyObject<'py>;
219+
220+
/// Inserts `default_value` into this dictionary with a key of `key` if the key is not already present in the
221+
/// dictionary. If the key was inserted, returns Ok((true, result)), otherwise returns Ok((false, result)) where
222+
/// `result` is the `value` associated with `key` after this function finishes. If an error happens, returns
223+
/// PyErr. This function uses
224+
/// [`PyDict_SetDefaultRef`](https://docs.python.org/3/c-api/dict.html#c.PyDict_SetDefaultRef) internally.
225+
fn set_default_with_result<K, V>(
226+
&self,
227+
key: K,
228+
default_value: V,
229+
) -> PyResult<(bool, Bound<'py, PyAny>)>
230+
where
231+
K: IntoPyObject<'py>,
232+
V: IntoPyObject<'py>;
210233
}
211234

212235
impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
@@ -385,6 +408,92 @@ impl<'py> PyDictMethods<'py> for Bound<'py, PyDict> {
385408
ffi::PyDict_Merge(self.as_ptr(), other.as_ptr(), 0)
386409
})
387410
}
411+
412+
fn set_default<K, V>(&self, key: K, default_value: V) -> PyResult<bool>
413+
where
414+
K: IntoPyObject<'py>,
415+
V: IntoPyObject<'py>,
416+
{
417+
fn inner(
418+
dict: &Bound<'_, PyDict>,
419+
key: Borrowed<'_, '_, PyAny>,
420+
value: Borrowed<'_, '_, PyAny>,
421+
) -> PyResult<bool> {
422+
setdefault_result_from_nonerror_return_code(err::error_on_minusone_with_result(
423+
dict.py(),
424+
unsafe {
425+
ffi::compat::PyDict_SetDefaultRef(
426+
dict.as_ptr(),
427+
key.as_ptr(),
428+
value.as_ptr(),
429+
std::ptr::null_mut(),
430+
)
431+
},
432+
))
433+
}
434+
let py = self.py();
435+
436+
inner(
437+
self,
438+
key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
439+
default_value
440+
.into_pyobject_or_pyerr(py)?
441+
.into_any()
442+
.as_borrowed(),
443+
)
444+
}
445+
446+
fn set_default_with_result<K, V>(
447+
&self,
448+
key: K,
449+
default_value: V,
450+
) -> PyResult<(bool, Bound<'py, PyAny>)>
451+
where
452+
K: IntoPyObject<'py>,
453+
V: IntoPyObject<'py>,
454+
{
455+
fn inner<'py>(
456+
dict: &Bound<'_, PyDict>,
457+
key: Borrowed<'_, '_, PyAny>,
458+
value: Borrowed<'_, '_, PyAny>,
459+
py: Python<'py>,
460+
) -> PyResult<(bool, Bound<'py, PyAny>)> {
461+
let mut result = std::ptr::NonNull::dangling().as_ptr();
462+
let code = setdefault_result_from_nonerror_return_code(
463+
err::error_on_minusone_with_result(dict.py(), unsafe {
464+
ffi::compat::PyDict_SetDefaultRef(
465+
dict.as_ptr(),
466+
key.as_ptr(),
467+
value.as_ptr(),
468+
&mut result,
469+
)
470+
}),
471+
)?;
472+
// SAFETY: the interpreter should have set this to a valid owned PyObject pointer
473+
let out_result = unsafe { result.assume_owned_unchecked(py) };
474+
Ok((code, out_result))
475+
}
476+
let py = self.py();
477+
inner(
478+
self,
479+
key.into_pyobject_or_pyerr(py)?.into_any().as_borrowed(),
480+
default_value
481+
.into_pyobject_or_pyerr(py)?
482+
.into_any()
483+
.as_borrowed(),
484+
py,
485+
)
486+
}
487+
}
488+
489+
fn setdefault_result_from_nonerror_return_code(code: PyResult<std::ffi::c_int>) -> PyResult<bool> {
490+
match code? {
491+
// inserted
492+
0 => Ok(true),
493+
// not inserted
494+
1 => Ok(false),
495+
x => panic!("Unknown return value from PyDict_SetDefaultRef: {x}"),
496+
}
388497
}
389498

390499
impl<'a, 'py> Borrowed<'a, 'py, PyDict> {
@@ -1669,4 +1778,54 @@ mod tests {
16691778
assert_eq!(dict.iter().count(), 3);
16701779
})
16711780
}
1781+
1782+
#[test]
1783+
fn test_set_default() {
1784+
Python::attach(|py| {
1785+
let dict = PyDict::new(py);
1786+
assert!(matches!(dict.set_default("hello", "world"), Ok(true)));
1787+
assert_eq!(
1788+
dict.get_item("hello")
1789+
.unwrap()
1790+
.unwrap()
1791+
.extract::<String>()
1792+
.unwrap(),
1793+
"world"
1794+
);
1795+
1796+
assert!(matches!(dict.set_default("hello", "foobar"), Ok(false)));
1797+
1798+
// unhashable
1799+
let invalid_key = PyList::new(py, vec![0]).unwrap();
1800+
assert!(dict.set_default(invalid_key, "foobar").is_err());
1801+
})
1802+
}
1803+
1804+
#[test]
1805+
fn test_set_default_with_result() {
1806+
Python::attach(|py| {
1807+
let dict = PyDict::new(py);
1808+
let res = dict.set_default_with_result("hello", "world");
1809+
assert!(res.is_ok());
1810+
let (inserted, value) = res.unwrap();
1811+
assert!(inserted);
1812+
assert!(value.extract::<String>().unwrap() == "world");
1813+
assert!(
1814+
dict.get_item("hello")
1815+
.unwrap()
1816+
.unwrap()
1817+
.extract::<String>()
1818+
.unwrap()
1819+
== "world"
1820+
);
1821+
1822+
let (inserted, value) = dict.set_default_with_result("hello", "foobar").unwrap();
1823+
assert!(!inserted);
1824+
assert_eq!(value.extract::<String>().unwrap(), "world");
1825+
1826+
// unhashable
1827+
let invalid_key = PyList::new(py, vec![0]).unwrap();
1828+
assert!(dict.set_default_with_result(invalid_key, "foobar").is_err());
1829+
})
1830+
}
16721831
}

0 commit comments

Comments
 (0)