Skip to content

Commit 6917b35

Browse files
authored
fix type confusion when different #[pyclass] types returned from #[new] (#6062)
* fix type confusion when different `#[pyclass]` types returned from `#[new]` * fixup * newsfragment * don't quote `unsafe` at user code * review feedback
1 parent 49428cb commit 6917b35

15 files changed

Lines changed: 340 additions & 196 deletions

newsfragments/6062.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix type confusion when returning a `#[pyclass]` from a different pyclass' `#[new]` method.

pyo3-macros-backend/src/method.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ pub struct FnSpec<'a> {
449449
pub asyncness: Option<syn::Token![async]>,
450450
pub unsafety: Option<syn::Token![unsafe]>,
451451
pub warnings: Vec<PyFunctionWarning>,
452+
#[cfg_attr(not(feature = "experimental-inspect"), expect(dead_code))]
452453
pub output: syn::ReturnType,
453454
}
454455

pyo3-macros-backend/src/pymethod.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ use crate::{
1616
use crate::{quotes, utils};
1717
use proc_macro2::{Span, TokenStream};
1818
use quote::{format_ident, quote, quote_spanned, ToTokens};
19+
use syn::LitCStr;
1920
use syn::{ext::IdentExt, spanned::Spanned, Field, Ident, Result};
20-
use syn::{parse_quote, LitCStr};
2121

2222
/// Generated code for a single pymethod item.
2323
pub struct MethodAndMethodDef {
@@ -1473,25 +1473,25 @@ fn generate_method_body(
14731473
}
14741474
});
14751475

1476-
let output = if let syn::ReturnType::Type(_, ty) = &spec.output {
1477-
let mut ty = ty.clone();
1478-
utils::elide_lifetimes(&mut ty);
1479-
ty
1480-
} else {
1481-
parse_quote!(())
1476+
let py = syn::Ident::new("py", Span::call_site());
1477+
let initializer = syn::Ident::new("initializer", Span::call_site());
1478+
let slf = syn::Ident::new("_slf", Span::call_site());
1479+
1480+
// Having just this call emitted at the span of the return value helps surface errors
1481+
// if the user passed an invalid return type.
1482+
let conversion = quote_spanned! { *output_span =>
1483+
#pyo3_path::impl_::pymethods::tp_new_impl::<_, #cls>(#py, #initializer, #slf)
14821484
};
1485+
14831486
let body = quote! {
14841487
#text_signature_impl
1485-
1486-
use #pyo3_path::impl_::pyclass::Probe as _;
14871488
#warnings
14881489
#arg_convert
1490+
14891491
let result = #call;
1490-
#pyo3_path::impl_::pymethods::tp_new_impl::<
1491-
_,
1492-
{ #pyo3_path::impl_::pyclass::IsPyClass::<#output>::VALUE },
1493-
{ #pyo3_path::impl_::pyclass::IsInitializerTuple::<#output>::VALUE }
1494-
>(py, result, _slf)
1492+
let value = #pyo3_path::impl_::wrap::OkWrapper::new(&result).ok_wrap(result)?;
1493+
let #initializer = #pyo3_path::impl_::pymethods::tp_new_resolver::<#cls, _>(&value).resolve(value);
1494+
unsafe { #conversion }
14951495
};
14961496
(arg_idents, arg_types, body)
14971497
}

pyo3-macros-backend/src/utils.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ pub(crate) fn locate_tokens_at(tokens: TokenStream, span: Span) -> TokenStream {
380380
///
381381
/// This is useful if `Self` is used in `const` context, where explicit
382382
/// lifetimes are not allowed (yet).
383+
#[cfg(feature = "experimental-inspect")]
383384
pub(crate) fn elide_lifetimes(ty: &mut syn::Type) {
384385
struct ElideLifetimesVisitor;
385386

src/impl_/pyclass.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ use crate::{
55
impl_::{
66
freelist::PyObjectFreeList,
77
pycell::{GetBorrowChecker, PyClassMutability, PyClassObjectBaseLayout},
8-
pyclass_init::PyObjectInit,
98
pymethods::{PyGetterDef, PyMethodDefType},
109
},
10+
internal::pyclass_init::PyObjectInit,
1111
pycell::{impl_::PyClassObjectLayout, PyBorrowError},
1212
types::{any::PyAnyMethods, PyBool},
1313
Borrowed, IntoPyObject, IntoPyObjectExt, Py, PyAny, PyClass, PyClassGuard, PyErr, PyResult,
@@ -1087,6 +1087,10 @@ impl<T> PyClassThreadChecker<T> for ThreadCheckerImpl {
10871087
note = "subclassing native types requires Python >= 3.12 when using the `abi3` feature",
10881088
)
10891089
)]
1090+
#[expect(
1091+
private_bounds,
1092+
reason = "`PyObjectInit` is an internal trait implementation"
1093+
)]
10901094
pub trait PyClassBaseType: Sized {
10911095
type LayoutAsBase: PyClassObjectBaseLayout<Self>;
10921096
type BaseNativeType;

src/impl_/pyclass_init.rs

Lines changed: 2 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,172 +1,2 @@
1-
//! Contains initialization utilities for `#[pyclass]`.
2-
use crate::exceptions::PyTypeError;
3-
use crate::ffi_ptr_ext::FfiPtrExt;
4-
use crate::impl_::pyclass::PyClassBaseType;
5-
use crate::internal::get_slot::TP_NEW;
6-
use crate::types::{PyTuple, PyType};
7-
use crate::{ffi, PyClass, PyClassInitializer, PyErr, PyResult, Python};
8-
use crate::{ffi::PyTypeObject, sealed::Sealed, type_object::PyTypeInfo};
9-
use core::marker::PhantomData;
10-
11-
/// Initializer for Python types.
12-
///
13-
/// This trait is intended to use internally for distinguishing `#[pyclass]` and
14-
/// Python native types.
15-
pub trait PyObjectInit<T>: Sized + Sealed {
16-
/// # Safety
17-
/// - `subtype` must be a valid pointer to a type object of T or a subclass.
18-
unsafe fn into_new_object(
19-
self,
20-
py: Python<'_>,
21-
subtype: *mut PyTypeObject,
22-
) -> PyResult<*mut ffi::PyObject>;
23-
}
24-
25-
/// Initializer for Python native types, like `PyDict`.
26-
pub struct PyNativeTypeInitializer<T: PyTypeInfo>(pub PhantomData<T>);
27-
28-
impl<T: PyTypeInfo> PyObjectInit<T> for PyNativeTypeInitializer<T> {
29-
unsafe fn into_new_object(
30-
self,
31-
py: Python<'_>,
32-
subtype: *mut PyTypeObject,
33-
) -> PyResult<*mut ffi::PyObject> {
34-
unsafe fn inner(
35-
py: Python<'_>,
36-
type_ptr: *mut PyTypeObject,
37-
subtype: *mut PyTypeObject,
38-
) -> PyResult<*mut ffi::PyObject> {
39-
let tp_new = unsafe {
40-
type_ptr
41-
.cast::<ffi::PyObject>()
42-
.assume_borrowed_unchecked(py)
43-
.cast_unchecked::<PyType>()
44-
.get_slot(TP_NEW)
45-
.ok_or_else(|| PyTypeError::new_err("base type without tp_new"))?
46-
};
47-
48-
// TODO: make it possible to provide real arguments to the base tp_new
49-
let obj =
50-
unsafe { tp_new(subtype, PyTuple::empty(py).as_ptr(), core::ptr::null_mut()) };
51-
if obj.is_null() {
52-
Err(PyErr::fetch(py))
53-
} else {
54-
Ok(obj)
55-
}
56-
}
57-
unsafe { inner(py, T::type_object_raw(py), subtype) }
58-
}
59-
}
60-
61-
pub trait PyClassInit<'py, const IS_PYCLASS: bool, const IS_INITIALIZER_TUPLE: bool>:
62-
seal_pyclass_init::Sealed<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE>
63-
{
64-
fn init(
65-
self,
66-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
67-
) -> PyResult<crate::Bound<'py, crate::PyAny>>;
68-
}
69-
70-
mod seal_pyclass_init {
71-
use crate::impl_::pyclass::{self, PyClassBaseType};
72-
use crate::impl_::pyclass_init::{PyClassInit, PyNativeTypeInitializer};
73-
use crate::{PyClass, PyClassInitializer, PyErr};
74-
75-
pub trait Sealed<'py, const IS_PYCLASS: bool, const IS_INITIALIZER_TUPLE: bool> {}
76-
77-
impl<'py, T> Sealed<'py, false, false> for T where T: crate::IntoPyObject<'py> {}
78-
impl<'py, T> Sealed<'py, true, false> for T
79-
where
80-
T: crate::PyClass,
81-
T::BaseType: pyclass::PyClassBaseType<Initializer = PyNativeTypeInitializer<T::BaseType>>,
82-
{
83-
}
84-
impl<'py, T, E, const IS_PYCLASS: bool, const IS_INITIALIZER_TUPLE: bool>
85-
Sealed<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE> for Result<T, E>
86-
where
87-
T: PyClassInit<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE>,
88-
E: Into<PyErr>,
89-
{
90-
}
91-
impl<'py, T> Sealed<'py, false, false> for PyClassInitializer<T> where T: PyClass {}
92-
impl<'py, S, B> Sealed<'py, false, true> for (S, B)
93-
where
94-
S: PyClass<BaseType = B>,
95-
B: PyClass + PyClassBaseType<Initializer = PyClassInitializer<B>>,
96-
B::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<B::BaseType>>,
97-
{
98-
}
99-
}
100-
101-
impl<'py, T> PyClassInit<'py, false, false> for T
102-
where
103-
T: crate::IntoPyObject<'py>,
104-
{
105-
fn init(
106-
self,
107-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
108-
) -> PyResult<crate::Bound<'py, crate::PyAny>> {
109-
self.into_pyobject(cls.py())
110-
.map(crate::BoundObject::into_any)
111-
.map(crate::BoundObject::into_bound)
112-
.map_err(Into::into)
113-
}
114-
}
115-
116-
impl<'py, T> PyClassInit<'py, true, false> for T
117-
where
118-
T: crate::PyClass,
119-
T::BaseType:
120-
super::pyclass::PyClassBaseType<Initializer = PyNativeTypeInitializer<T::BaseType>>,
121-
{
122-
fn init(
123-
self,
124-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
125-
) -> PyResult<crate::Bound<'py, crate::PyAny>> {
126-
PyClassInitializer::from(self).init(cls)
127-
}
128-
}
129-
130-
impl<'py, T, E, const IS_PYCLASS: bool, const IS_INITIALIZER_TUPLE: bool>
131-
PyClassInit<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE> for Result<T, E>
132-
where
133-
T: PyClassInit<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE>,
134-
E: Into<PyErr>,
135-
{
136-
fn init(
137-
self,
138-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
139-
) -> PyResult<crate::Bound<'py, crate::PyAny>> {
140-
self.map_err(Into::into)?.init(cls)
141-
}
142-
}
143-
144-
impl<'py, T> PyClassInit<'py, false, false> for PyClassInitializer<T>
145-
where
146-
T: PyClass,
147-
{
148-
fn init(
149-
self,
150-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
151-
) -> PyResult<crate::Bound<'py, crate::PyAny>> {
152-
unsafe {
153-
self.create_class_object_of_type(cls.py(), cls.as_ptr().cast())
154-
.map(crate::Bound::into_any)
155-
}
156-
}
157-
}
158-
159-
impl<'py, S, B> PyClassInit<'py, false, true> for (S, B)
160-
where
161-
S: PyClass<BaseType = B>,
162-
B: PyClass + PyClassBaseType<Initializer = PyClassInitializer<B>>,
163-
B::BaseType: PyClassBaseType<Initializer = PyNativeTypeInitializer<B::BaseType>>,
164-
{
165-
fn init(
166-
self,
167-
cls: crate::Borrowed<'_, 'py, crate::types::PyType>,
168-
) -> PyResult<crate::Bound<'py, crate::PyAny>> {
169-
let (sub, base) = self;
170-
PyClassInitializer::from(base).add_subclass(sub).init(cls)
171-
}
172-
}
1+
/// Re-exported so that macros can name this internal type.
2+
pub use crate::internal::pyclass_init::PyNativeTypeInitializer;

src/impl_/pymethods.rs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ use crate::impl_::callback::IntoPyCallbackOutput;
33
use crate::impl_::panic::PanicTrap;
44
use crate::impl_::pycell::PyClassObjectBaseLayout;
55
use crate::internal::get_slot::{get_slot, TP_BASE, TP_CLEAR, TP_TRAVERSE};
6+
use crate::internal::pyclass_init::PyClassInit;
67
use crate::internal::state::ForbidAttaching;
78
use crate::pycell::impl_::{PyClassBorrowChecker as _, PyClassObjectLayout};
89
use crate::types::PyType;
9-
use crate::{ffi, Bound, Py, PyAny, PyClass, PyErr, PyResult, PyTraverseError, PyVisit, Python};
10+
use crate::{
11+
ffi, Borrowed, Bound, Py, PyAny, PyClass, PyErr, PyResult, PyTraverseError, PyVisit, Python,
12+
};
1013
use core::ffi::CStr;
1114
use core::ffi::{c_int, c_void};
1215
use core::fmt;
@@ -688,18 +691,26 @@ pub trait AsyncIterResultOptionKind {
688691

689692
impl<Value, Error> AsyncIterResultOptionKind for Result<Option<Value>, Error> {}
690693

691-
pub unsafe fn tp_new_impl<'py, T, const IS_PYCLASS: bool, const IS_INITIALIZER_TUPLE: bool>(
694+
/// Re-exported so that `#[new]` generated code can resolve the type tag for `tp_new_impl`
695+
pub use crate::internal::pyclass_init::tp_new_resolver;
696+
697+
#[expect(
698+
private_bounds,
699+
reason = "`PyClassInit` is not a public trait, bound exist for diagnostics"
700+
)]
701+
/// # SAFETY
702+
/// - `cls` must be the type object for `ClassT` (or a subclass)
703+
pub unsafe fn tp_new_impl<'py, InitializerT, ClassT>(
692704
py: Python<'py>,
693-
obj: T,
705+
initializer: InitializerT,
694706
cls: *mut ffi::PyTypeObject,
695707
) -> PyResult<*mut ffi::PyObject>
696708
where
697-
T: super::pyclass_init::PyClassInit<'py, IS_PYCLASS, IS_INITIALIZER_TUPLE>,
709+
InitializerT: PyClassInit<'py, ClassT>,
698710
{
699-
unsafe {
700-
obj.init(crate::Borrowed::from_ptr_unchecked(py, cls.cast()).cast_unchecked())
701-
.map(Bound::into_ptr)
702-
}
711+
// SAFETY: caller has guaranteed `cls` is the correct object
712+
unsafe { initializer.init(Borrowed::from_ptr_unchecked(py, cls.cast()).cast_unchecked()) }
713+
.map(Bound::into_ptr)
703714
}
704715

705716
#[cfg(test)]

src/impl_/wrap.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,34 @@ impl<T> SomeWrap<T> for Option<T> {
2323
}
2424
}
2525

26+
pub struct OkWrapper<T>(OkWrapperInner<T>);
27+
pub struct OkWrapperInner<T>(PhantomData<T>);
28+
29+
impl<T> OkWrapper<T> {
30+
pub fn new(_: &T) -> Self {
31+
Self(OkWrapperInner(PhantomData))
32+
}
33+
}
34+
35+
impl<T> Deref for OkWrapper<T> {
36+
type Target = OkWrapperInner<T>;
37+
fn deref(&self) -> &Self::Target {
38+
&self.0
39+
}
40+
}
41+
42+
impl<T, E> OkWrapper<Result<T, E>> {
43+
pub fn ok_wrap(&self, value: Result<T, E>) -> Result<T, E> {
44+
value
45+
}
46+
}
47+
48+
impl<T> OkWrapperInner<T> {
49+
pub fn ok_wrap(&self, value: T) -> Result<T, Infallible> {
50+
Ok(value)
51+
}
52+
}
53+
2654
// Hierarchy of conversions used in the function return type machinery
2755
pub struct Converter<T>(EmptyTupleConverter<T>);
2856
pub struct EmptyTupleConverter<T>(IntoPyObjectConverter<T>);
@@ -155,4 +183,15 @@ mod tests {
155183
let b: Option<u8> = SomeWrap::wrap(None);
156184
assert_eq!(b, None);
157185
}
186+
187+
#[test]
188+
fn wrap_result() {
189+
let a = 42;
190+
let Ok(a) = OkWrapper::new(&a).ok_wrap(a);
191+
assert_eq!(a, 42);
192+
193+
let b = Result::<_, String>::Ok(42);
194+
let b = OkWrapper::new(&b).ok_wrap(b);
195+
assert_eq!(b, Ok(42));
196+
}
158197
}

src/internal.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
//! Holding place for code which is not intended to be reachable from outside of PyO3.
22
33
pub(crate) mod get_slot;
4+
pub(crate) mod pyclass_init;
45
pub(crate) mod state;

0 commit comments

Comments
 (0)