Skip to content

Commit 90f4ea5

Browse files
refactor: centralize free-threaded layout branching
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent a37ba3d commit 90f4ea5

File tree

2 files changed

+60
-42
lines changed

2 files changed

+60
-42
lines changed

src/C/extras.jl

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,57 @@
11
asptr(x) = Base.unsafe_convert(PyPtr, x)
22

3-
Py_Type(x) = Base.GC.@preserve x begin
4-
if CTX.is_free_threaded
5-
PyPtr(UnsafePtr{PyObjectFT}(asptr(x)).type[!])
3+
# Free-threaded CPython builds ("3.14t") currently have different C struct layouts,
4+
# but there is no stable ABI yet. To keep the code manageable, we centralize the
5+
# branching in a single macro that rewrites type names in the expression.
6+
const _FT_TYPE_REPLACEMENTS = Dict{Symbol,Symbol}(
7+
:PyObject => :PyObjectFT,
8+
:PyVarObject => :PyVarObjectFT,
9+
:PyMemoryViewObject => :PyMemoryViewObjectFT,
10+
:PySimpleObject => :PySimpleObjectFT,
11+
# Used from JlWrap/C.jl via `C.@ft`.
12+
:PyJuliaValueObject => :PyJuliaValueObjectFT,
13+
)
14+
15+
_ft_replace(sym::Symbol) = get(_FT_TYPE_REPLACEMENTS, sym, sym)
16+
17+
function _ft_transform(ex)
18+
if ex isa Symbol
19+
return _ft_replace(ex)
20+
elseif ex isa QuoteNode
21+
v = ex.value
22+
return v isa Symbol ? QuoteNode(_ft_replace(v)) : ex
23+
elseif ex isa Expr
24+
# Handle dotted refs like `C.PyObject` (Expr(:., ...)).
25+
if ex.head === :. && length(ex.args) == 2 && ex.args[2] isa QuoteNode && ex.args[2].value isa Symbol
26+
return Expr(:., _ft_transform(ex.args[1]), QuoteNode(_ft_replace(ex.args[2].value)))
27+
end
28+
return Expr(ex.head, map(_ft_transform, ex.args)...)
629
else
7-
PyPtr(UnsafePtr{PyObject}(asptr(x)).type[!])
30+
return ex
831
end
932
end
1033

34+
"""
35+
@ft expr
36+
37+
Evaluate `expr`, but when `CTX.is_free_threaded` is true (CPython "free-threaded"
38+
builds), rewrite internal type names like `PyObject` → `PyObjectFT` inside the
39+
expression.
40+
41+
This keeps free-threaded branching centralized, so we don't scatter `if
42+
CTX.is_free_threaded` throughout the code.
43+
"""
44+
macro ft(ex)
45+
ex_ft = _ft_transform(ex)
46+
return :(if CTX.is_free_threaded
47+
$(esc(ex_ft))
48+
else
49+
$(esc(ex))
50+
end)
51+
end
52+
53+
Py_Type(x) = Base.GC.@preserve x @ft PyPtr(UnsafePtr{PyObject}(asptr(x)).type[!])
54+
1155
PyObject_Type(x) = Base.GC.@preserve x (t = Py_Type(asptr(x)); Py_IncRef(t); t)
1256

1357
Py_TypeCheck(o, t) = Base.GC.@preserve o t PyType_IsSubtype(Py_Type(asptr(o)), asptr(t))
@@ -16,13 +60,7 @@ Py_TypeCheckFast(o, f::Integer) = Base.GC.@preserve o PyType_IsSubtypeFast(Py_Ty
1660
PyType_IsSubtypeFast(t, f::Integer) =
1761
Base.GC.@preserve t Cint(!iszero(PyType_GetFlags(asptr(t)) & f))
1862

19-
PyMemoryView_GET_BUFFER(m) = Base.GC.@preserve m begin
20-
if CTX.is_free_threaded
21-
Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObjectFT}(asptr(m)).view)
22-
else
23-
Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(asptr(m)).view)
24-
end
25-
end
63+
PyMemoryView_GET_BUFFER(m) = Base.GC.@preserve m @ft Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(asptr(m)).view)
2664

2765
PyType_CheckBuffer(t) = Base.GC.@preserve t begin
2866
getbuf = PyType_GetSlot(asptr(t), Py_bf_getbuffer)
@@ -79,13 +117,7 @@ function PyOS_RunInputHook()
79117
end
80118

81119
function PySimpleObject_GetValue(::Type{T}, o) where {T}
82-
Base.GC.@preserve o begin
83-
if CTX.is_free_threaded
84-
UnsafePtr{PySimpleObjectFT{T}}(asptr(o)).value[!]
85-
else
86-
UnsafePtr{PySimpleObject{T}}(asptr(o)).value[!]
87-
end
88-
end
120+
Base.GC.@preserve o @ft UnsafePtr{PySimpleObject{T}}(asptr(o)).value[!]
89121
end
90122

91123
# FAST REFCOUNTING

src/JlWrap/C.jl

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -26,35 +26,23 @@ const PYJLVALUES = []
2626
# unused indices in PYJLVALUES
2727
const PYJLFREEVALUES = Int[]
2828

29-
# Choose the layout based on whether Python is free-threaded.
30-
pyjl_obj_type() = C.CTX.is_free_threaded ? PyJuliaValueObjectFT : PyJuliaValueObject
31-
32-
function pyjl_obj_ptr(o::C.PyPtr)
33-
if C.CTX.is_free_threaded
34-
return UnsafePtr{PyJuliaValueObjectFT}(o)
35-
else
36-
return UnsafePtr{PyJuliaValueObject}(o)
37-
end
38-
end
39-
4029
function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
4130
alloc = C.PyType_GetSlot(t, C.Py_tp_alloc)
4231
alloc == C_NULL && return C.PyNULL
4332
o = ccall(alloc, C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
4433
o == C.PyNULL && return C.PyNULL
45-
p = pyjl_obj_ptr(o)
46-
p.weaklist[] = C.PyNULL
47-
p.value[] = 0
34+
C.@ft UnsafePtr{PyJuliaValueObject}(o).weaklist[] = C.PyNULL
35+
C.@ft UnsafePtr{PyJuliaValueObject}(o).value[] = 0
4836
return o
4937
end
5038

5139
function _pyjl_dealloc(o::C.PyPtr)
52-
idx = pyjl_obj_ptr(o).value[]
40+
idx = C.@ft UnsafePtr{PyJuliaValueObject}(o).value[]
5341
if idx != 0
5442
PYJLVALUES[idx] = nothing
5543
push!(PYJLFREEVALUES, idx)
5644
end
57-
pyjl_obj_ptr(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
45+
(C.@ft UnsafePtr{PyJuliaValueObject}(o).weaklist[!]) == C.PyNULL || C.PyObject_ClearWeakRefs(o)
5846
freeptr = C.PyType_GetSlot(C.Py_Type(o), C.Py_tp_free)
5947
freeptr == C_NULL || ccall(freeptr, Cvoid, (C.PyPtr,), o)
6048
nothing
@@ -335,13 +323,12 @@ function init_c()
335323

336324
# Create members for weakref support
337325
empty!(_pyjlbase_members)
338-
objT = pyjl_obj_type()
339326
push!(
340327
_pyjlbase_members,
341328
C.PyMemberDef(
342329
name = pointer(_pyjlbase_weaklistoffset_name),
343330
typ = C.Py_T_PYSSIZET,
344-
offset = fieldoffset(objT, 3),
331+
offset = (C.@ft fieldoffset(PyJuliaValueObject, 3)),
345332
flags = C.Py_READONLY,
346333
),
347334
C.PyMemberDef(), # NULL terminator
@@ -363,7 +350,7 @@ function init_c()
363350
# Create PyType_Spec
364351
_pyjlbase_spec[] = C.PyType_Spec(
365352
name = pointer(_pyjlbase_name),
366-
basicsize = sizeof(objT),
353+
basicsize = (C.@ft sizeof(PyJuliaValueObject)),
367354
flags = C.Py_TPFLAGS_BASETYPE | C.Py_TPFLAGS_HAVE_VERSION_TAG,
368355
slots = pointer(_pyjlbase_slots),
369356
)
@@ -380,14 +367,13 @@ function __init__()
380367
init_c()
381368
end
382369

383-
PyJuliaValue_IsNull(o) = Base.GC.@preserve o pyjl_obj_ptr(C.asptr(o)).value[] == 0
370+
PyJuliaValue_IsNull(o) = Base.GC.@preserve o (C.@ft UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]) == 0
384371

385-
PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[pyjl_obj_ptr(C.asptr(o)).value[]]
372+
PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[(C.@ft UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[])]
386373

387374
PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
388375
o = C.asptr(_o)
389-
p = pyjl_obj_ptr(o)
390-
idx = p.value[]
376+
idx = C.@ft UnsafePtr{PyJuliaValueObject}(o).value[]
391377
if idx == 0
392378
if isempty(PYJLFREEVALUES)
393379
push!(PYJLVALUES, v)
@@ -396,7 +382,7 @@ PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
396382
idx = pop!(PYJLFREEVALUES)
397383
PYJLVALUES[idx] = v
398384
end
399-
p.value[] = idx
385+
C.@ft UnsafePtr{PyJuliaValueObject}(o).value[] = idx
400386
else
401387
PYJLVALUES[idx] = v
402388
end

0 commit comments

Comments
 (0)