Skip to content

Commit 3932d02

Browse files
CopilotMatthieuDartiailh
authored andcommitted
Implement trusted self conversion for all extension-type method and slot wrappers
1 parent 71c9365 commit 3932d02

10 files changed

Lines changed: 619 additions & 141 deletions

pyo3-macros-backend/src/method.rs

Lines changed: 104 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ impl FnType {
263263
&self,
264264
cls: Option<&syn::Type>,
265265
error_mode: ExtractErrorMode,
266-
descriptor_slot_receiver: bool,
266+
self_conversion: SelfConversionPolicy,
267267
holders: &mut Holders,
268268
ctx: &Ctx,
269269
) -> Option<TokenStream> {
@@ -273,7 +273,7 @@ impl FnType {
273273
Some(st.receiver(
274274
cls.expect("no class given for Fn with a \"self\" receiver"),
275275
error_mode,
276-
descriptor_slot_receiver,
276+
self_conversion,
277277
holders,
278278
ctx,
279279
))
@@ -322,6 +322,41 @@ pub enum SelfType {
322322
},
323323
}
324324

325+
/// Receiver conversion policy for extension-type method wrappers.
326+
///
327+
/// Controls whether the `self` receiver is validated with a runtime type check
328+
/// (`Checked`) or treated as trusted and cast directly without checking
329+
/// (`Trusted`).
330+
///
331+
/// # Invariant
332+
///
333+
/// The `Trusted` path is valid due to CPython's slot/method receiver contract:
334+
/// when CPython dispatches a method call on an extension type — whether through
335+
/// a type slot or through `tp_methods` — the receiver is guaranteed to be an
336+
/// instance of the owning type (or a compatible subtype). For `tp_methods`
337+
/// entries, CPython's method-wrapper descriptor enforces this before the C
338+
/// function is reached.
339+
///
340+
/// `Checked` should be used in cases where that guarantee does not hold:
341+
/// - Standalone `#[pyfunction]`s (no class receiver).
342+
/// - Number-protocol binary operator fragments (`__add__`, `__radd__`, …,
343+
/// `__pow__`, `__rpow__`): CPython combines the forward and reflected
344+
/// fragments into a single `nb_add`/`nb_power` slot, and the runtime helper
345+
/// may call the reflected fragment with the operands swapped, meaning `_slf`
346+
/// can arrive with a non-class type. The existing
347+
/// `ExtractErrorMode::NotImplemented` behaviour on type mismatch is preserved
348+
/// by using `Checked` for those fragments.
349+
#[derive(Clone, Copy, Debug)]
350+
pub enum SelfConversionPolicy {
351+
/// The receiver's type is guaranteed by CPython's slot/method dispatch contract.
352+
/// Used for all extension-type method and slot entrypoints.
353+
Trusted,
354+
/// The receiver's type is verified at runtime. Used for standalone functions
355+
/// and number-protocol binary operator fragments where the CPython dispatch
356+
/// contract does not guarantee the receiver type.
357+
Checked,
358+
}
359+
325360
#[derive(Clone, Copy)]
326361
pub enum ExtractErrorMode {
327362
NotImplemented,
@@ -348,7 +383,7 @@ impl SelfType {
348383
&self,
349384
cls: &syn::Type,
350385
error_mode: ExtractErrorMode,
351-
descriptor_slot_receiver: bool,
386+
self_conversion: SelfConversionPolicy,
352387
holders: &mut Holders,
353388
ctx: &Ctx,
354389
) -> TokenStream {
@@ -370,22 +405,47 @@ impl SelfType {
370405
};
371406
let arg =
372407
quote! { unsafe { #pyo3_path::impl_::extract_argument::#cast_fn(#py, #slf) } };
373-
let method = if *mutable {
374-
syn::Ident::new("extract_pyclass_ref_mut", *span)
375-
} else {
376-
syn::Ident::new("extract_pyclass_ref", *span)
377-
};
378408
let holder = holders.push_holder(*span);
379409
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
380-
error_mode.handle_error(
381-
quote_spanned! { *span =>
382-
#pyo3_path::impl_::extract_argument::#method::<#cls>(
383-
#arg,
384-
&mut #holder,
410+
match self_conversion {
411+
SelfConversionPolicy::Trusted => {
412+
let method = if *mutable {
413+
syn::Ident::new("extract_pyclass_ref_mut_trusted", *span)
414+
} else {
415+
syn::Ident::new("extract_pyclass_ref_trusted", *span)
416+
};
417+
// Use `quote!` (not `quote_spanned!`) for the `unsafe` block so that
418+
// the `unsafe` keyword has `Span::call_site()` and does not inherit the
419+
// user's code span. This prevents triggering `#![forbid(unsafe_code)]`
420+
// in user crates (see the analogous comment in `impl_py_getter_def`).
421+
// Safety: slot wrappers are only installed on the extension type itself.
422+
// CPython's slot dispatch contract ensures the receiver is an instance
423+
// of the correct type before invoking the slot.
424+
let trusted_call = quote! {
425+
unsafe { #pyo3_path::impl_::extract_argument::#method::<#cls>(
426+
#arg,
427+
&mut #holder,
428+
) }
429+
};
430+
error_mode.handle_error(trusted_call, ctx)
431+
}
432+
SelfConversionPolicy::Checked => {
433+
let method = if *mutable {
434+
syn::Ident::new("extract_pyclass_ref_mut", *span)
435+
} else {
436+
syn::Ident::new("extract_pyclass_ref", *span)
437+
};
438+
error_mode.handle_error(
439+
quote_spanned! { *span =>
440+
#pyo3_path::impl_::extract_argument::#method::<#cls>(
441+
#arg,
442+
&mut #holder,
443+
)
444+
},
445+
ctx,
385446
)
386-
},
387-
ctx,
388-
)
447+
}
448+
}
389449
}
390450
SelfType::TryFromBoundRef { span, non_null } => {
391451
let bound_ref = if *non_null {
@@ -394,22 +454,27 @@ impl SelfType {
394454
quote! { unsafe { #pyo3_path::Bound::ref_from_ptr(#py, &#slf) } }
395455
};
396456
let pyo3_path = pyo3_path.to_tokens_spanned(*span);
397-
let receiver = if descriptor_slot_receiver {
398-
quote_spanned! { *span =>
399-
// Safety: descriptor slot wrappers are only installed on the descriptor
400-
// type itself. CPython calls those slots with `self` set to the
401-
// descriptor object found during lookup, and explicit Python calls to
402-
// `__get__`, `__set__`, and `__delete__` first pass through CPython's
403-
// slot wrapper, which rejects receivers of the wrong type before
404-
// reaching this generated wrapper.
405-
::std::result::Result::<_, #pyo3_path::PyErr>::Ok(unsafe {
406-
#bound_ref.cast_unchecked::<#cls>()
407-
})
457+
let receiver = match self_conversion {
458+
SelfConversionPolicy::Trusted => {
459+
// Use `quote!` (not `quote_spanned!`) for the inner `unsafe` block so
460+
// that it has `Span::call_site()` and does not trigger
461+
// `#![forbid(unsafe_code)]` in user crates.
462+
// Safety: slot wrappers are only installed on the extension type
463+
// itself. CPython's slot dispatch contract ensures the receiver is
464+
// an instance of the correct type (or a compatible subtype) before
465+
// invoking the slot.
466+
let cast = quote! {
467+
unsafe { #bound_ref.cast_unchecked::<#cls>() }
468+
};
469+
quote_spanned! { *span =>
470+
::std::result::Result::<_, #pyo3_path::PyErr>::Ok(#cast)
471+
}
408472
}
409-
} else {
410-
quote_spanned! { *span =>
411-
#bound_ref.cast::<#cls>()
412-
.map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
473+
SelfConversionPolicy::Checked => {
474+
quote_spanned! { *span =>
475+
#bound_ref.cast::<#cls>()
476+
.map_err(::std::convert::Into::<#pyo3_path::PyErr>::into)
477+
}
413478
}
414479
};
415480
error_mode.handle_error(
@@ -697,6 +762,7 @@ impl<'a> FnSpec<'a> {
697762
ident: &proc_macro2::Ident,
698763
cls: Option<&syn::Type>,
699764
convention: CallingConvention,
765+
self_conversion: SelfConversionPolicy,
700766
ctx: &Ctx,
701767
) -> Result<TokenStream> {
702768
let Ctx {
@@ -719,9 +785,13 @@ impl<'a> FnSpec<'a> {
719785
}
720786

721787
let rust_call = |args: Vec<TokenStream>, mut holders: Holders| {
722-
let self_arg = self
723-
.tp
724-
.self_arg(cls, ExtractErrorMode::Raise, false, &mut holders, ctx);
788+
let self_arg = self.tp.self_arg(
789+
cls,
790+
ExtractErrorMode::Raise,
791+
self_conversion,
792+
&mut holders,
793+
ctx,
794+
);
725795
let init_holders = holders.init_holders(ctx);
726796

727797
// We must assign the output_span to the return value of the call,

pyo3-macros-backend/src/pyfunction.rs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::{
1010
self, get_pyo3_options, take_attributes, take_pyo3_options, CrateAttribute,
1111
FromPyWithAttribute, NameAttribute, TextSignatureAttribute,
1212
},
13-
method::{self, CallingConvention, FnArg},
13+
method::{self, CallingConvention, FnArg, SelfConversionPolicy},
1414
pymethod::check_generic,
1515
};
1616
use proc_macro2::{Span, TokenStream};
@@ -430,7 +430,13 @@ pub fn impl_wrap_pyfunction(
430430
);
431431
}
432432
let calling_convention = CallingConvention::from_signature(&spec.signature);
433-
let wrapper = spec.get_wrapper_function(&wrapper_ident, None, calling_convention, ctx)?;
433+
let wrapper = spec.get_wrapper_function(
434+
&wrapper_ident,
435+
None,
436+
calling_convention,
437+
SelfConversionPolicy::Checked,
438+
ctx,
439+
)?;
434440
let methoddef = spec.get_methoddef(
435441
wrapper_ident,
436442
spec.get_doc(&func.attrs).as_ref(),

0 commit comments

Comments
 (0)