Skip to content

Commit 5783d27

Browse files
fix: support CPython free-threaded builds
Co-authored-by: Miles Cranmer <miles.cranmer@gmail.com>
1 parent ef6c404 commit 5783d27

File tree

5 files changed

+110
-29
lines changed

5 files changed

+110
-29
lines changed

src/C/consts.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,20 @@ end
124124
type::Ptr{Cvoid} = C_NULL # really is Ptr{PyObject} or Ptr{PyTypeObject} but Julia 1.3 and below get the layout incorrect when circular types are involved
125125
end
126126

127+
@kwdef struct PyMutex
128+
bits::Cuchar = 0
129+
end
130+
131+
@kwdef struct PyObjectFT
132+
tid::Csize_t = 0
133+
flags::Cushort = 0
134+
mutex::PyMutex = PyMutex()
135+
gc_bits::Cuchar = 0
136+
ref_local::Cuint = 0
137+
ref_shared::Py_ssize_t = 0
138+
type::Ptr{Cvoid} = C_NULL # really is Ptr{PyObject} or Ptr{PyTypeObject} but Julia 1.3 and below get the layout incorrect when circular types are involved
139+
end
140+
127141
const PyPtr = Ptr{PyObject}
128142
const PyNULL = PyPtr(0)
129143

@@ -139,6 +153,11 @@ Base.unsafe_convert(::Type{PyPtr}, o::PyObjectRef) = o.ptr
139153
size::Py_ssize_t = 0
140154
end
141155

156+
@kwdef struct PyVarObjectFT
157+
ob_base::PyObjectFT = PyObjectFT()
158+
size::Py_ssize_t = 0
159+
end
160+
142161
@kwdef struct PyMethodDef
143162
name::Cstring = C_NULL
144163
meth::Ptr{Cvoid} = C_NULL
@@ -249,6 +268,16 @@ end
249268
weakreflist::PyPtr = PyNULL
250269
end
251270

271+
@kwdef struct PyMemoryViewObjectFT
272+
ob_base::PyVarObjectFT = PyVarObjectFT()
273+
mbuf::PyPtr = PyNULL
274+
hash::Py_hash_t = 0
275+
flags::Cint = 0
276+
exports::Py_ssize_t = 0
277+
view::Py_buffer = Py_buffer()
278+
weakreflist::PyPtr = PyNULL
279+
end
280+
252281
@kwdef struct PyTypeObject
253282
ob_base::PyVarObject = PyVarObject()
254283
name::Cstring = C_NULL
@@ -327,6 +356,11 @@ end
327356
value::T
328357
end
329358

359+
@kwdef struct PySimpleObjectFT{T}
360+
ob_base::PyObjectFT = PyObjectFT()
361+
value::T
362+
end
363+
330364
@kwdef struct PyArrayInterface
331365
two::Cint = 0
332366
nd::Cint = 0

src/C/context.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ A handle to a loaded instance of libpython, its interpreter, function pointers,
1717
pyhome_w::Any = missing
1818
which::Symbol = :unknown # :CondaPkg, :PyCall, :embedded or :unknown
1919
version::Union{VersionNumber,Missing} = missing
20+
is_free_threaded::Bool = false
2021
end
2122

2223
const CTX = Context()
@@ -312,10 +313,11 @@ function init_context()
312313
v"3.10" CTX.version < v"4" || error(
313314
"Only Python 3.10+ is supported, this is Python $(CTX.version) at $(CTX.exe_path===missing ? "unknown location" : CTX.exe_path).",
314315
)
316+
CTX.is_free_threaded = occursin("free-threading build", verstr)
315317

316318
launch_on_main_thread(Threads.threadid()) # makes on_main_thread usable
317319

318-
@debug "Initialized PythonCall.jl" CTX.is_embedded CTX.is_initialized CTX.exe_path CTX.lib_path CTX.lib_ptr CTX.pyprogname CTX.pyhome CTX.version
320+
@debug "Initialized PythonCall.jl" CTX.is_embedded CTX.is_initialized CTX.exe_path CTX.lib_path CTX.lib_ptr CTX.pyprogname CTX.pyhome CTX.version CTX.is_free_threaded
319321

320322
return
321323
end

src/C/extras.jl

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,58 @@
11
asptr(x) = Base.unsafe_convert(PyPtr, x)
22

3-
Py_Type(x) = Base.GC.@preserve x PyPtr(UnsafePtr(asptr(x)).type[!])
3+
Py_Type(x) = Base.GC.@preserve x begin
4+
if CTX.is_free_threaded
5+
PyPtr(UnsafePtr{PyObjectFT}(asptr(x)).type[!])
6+
else
7+
PyPtr(UnsafePtr{PyObject}(asptr(x)).type[!])
8+
end
9+
end
410

511
PyObject_Type(x) = Base.GC.@preserve x (t = Py_Type(asptr(x)); Py_IncRef(t); t)
612

713
Py_TypeCheck(o, t) = Base.GC.@preserve o t PyType_IsSubtype(Py_Type(asptr(o)), asptr(t))
814
Py_TypeCheckFast(o, f::Integer) = Base.GC.@preserve o PyType_IsSubtypeFast(Py_Type(asptr(o)), f)
915

1016
PyType_IsSubtypeFast(t, f::Integer) =
11-
Base.GC.@preserve t Cint(!iszero(UnsafePtr{PyTypeObject}(asptr(t)).flags[] & f))
17+
Base.GC.@preserve t Cint(!iszero(PyType_GetFlags(asptr(t)) & f))
1218

13-
PyMemoryView_GET_BUFFER(m) = Base.GC.@preserve m Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(asptr(m)).view)
19+
PyMemoryView_GET_BUFFER(m) = Base.GC.@preserve m begin
20+
if CTX.is_free_threaded
21+
Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObjectFT}(asptr(m)).view)
22+
else
23+
Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(asptr(m)).view)
24+
end
25+
end
1426

1527
PyType_CheckBuffer(t) = Base.GC.@preserve t begin
16-
p = UnsafePtr{PyTypeObject}(asptr(t)).as_buffer[]
17-
return p != C_NULL && p.get[!] != C_NULL
28+
getbuf = PyType_GetSlot(asptr(t), Py_bf_getbuffer)
29+
return getbuf != C_NULL
1830
end
1931

2032
PyObject_CheckBuffer(o) = Base.GC.@preserve o PyType_CheckBuffer(Py_Type(asptr(o)))
2133

2234
PyObject_GetBuffer(_o, b, flags) = Base.GC.@preserve _o begin
2335
o = asptr(_o)
24-
p = UnsafePtr{PyTypeObject}(Py_Type(o)).as_buffer[]
25-
if p == C_NULL || p.get[!] == C_NULL
26-
PyErr_SetString(
27-
POINTERS.PyExc_TypeError,
28-
"a bytes-like object is required, not '$(String(UnsafePtr{PyTypeObject}(Py_Type(o)).name[]))'",
29-
)
36+
getbuf = PyType_GetSlot(Py_Type(o), Py_bf_getbuffer)
37+
if getbuf == C_NULL
38+
msg = if CTX.is_free_threaded
39+
"a bytes-like object is required"
40+
else
41+
"a bytes-like object is required, not '$(String(UnsafePtr{PyTypeObject}(Py_Type(o)).name[]))'"
42+
end
43+
PyErr_SetString(POINTERS.PyExc_TypeError, msg)
3044
return Cint(-1)
3145
end
32-
return ccall(p.get[!], Cint, (PyPtr, Ptr{Py_buffer}, Cint), o, b, flags)
46+
return ccall(getbuf, Cint, (PyPtr, Ptr{Py_buffer}, Cint), o, b, flags)
3347
end
3448

3549
PyBuffer_Release(_b) = begin
3650
b = UnsafePtr(Base.unsafe_convert(Ptr{Py_buffer}, _b))
3751
o = b.obj[]
3852
o == C_NULL && return
39-
p = UnsafePtr{PyTypeObject}(Py_Type(o)).as_buffer[]
40-
if (p != C_NULL && p.release[!] != C_NULL)
41-
ccall(p.release[!], Cvoid, (PyPtr, Ptr{Py_buffer}), o, b)
53+
releasebuf = PyType_GetSlot(Py_Type(o), Py_bf_releasebuffer)
54+
if releasebuf != C_NULL
55+
ccall(releasebuf, Cvoid, (PyPtr, Ptr{Py_buffer}), o, b)
4256
end
4357
b.obj[] = C_NULL
4458
Py_DecRef(o)
@@ -65,7 +79,13 @@ function PyOS_RunInputHook()
6579
end
6680

6781
function PySimpleObject_GetValue(::Type{T}, o) where {T}
68-
Base.GC.@preserve o UnsafePtr{PySimpleObject{T}}(asptr(o)).value[!]
82+
Base.GC.@preserve o begin
83+
if CTX.is_free_threaded
84+
UnsafePtr{PySimpleObjectFT{T}}(asptr(o)).value[!]
85+
else
86+
UnsafePtr{PySimpleObject{T}}(asptr(o)).value[!]
87+
end
88+
end
6989
end
7090

7191
# FAST REFCOUNTING

src/C/pointers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ const CAPI_FUNC_SIGS = Dict{Symbol,Pair{Tuple,Type}}(
7777
:PyType_Ready => (PyPtr,) => Cint,
7878
:PyType_GenericNew => (PyPtr, PyPtr, PyPtr) => PyPtr,
7979
:PyType_FromSpec => (Ptr{Cvoid},) => PyPtr,
80+
:PyType_GetFlags => (PyPtr,) => Culong,
81+
:PyType_GetSlot => (PyPtr, Cint) => Ptr{Cvoid},
8082
# MAPPING
8183
:PyMapping_HasKeyString => (PyPtr, Ptr{Cchar}) => Cint,
8284
:PyMapping_SetItemString => (PyPtr, Ptr{Cchar}, PyPtr) => Cint,

src/JlWrap/C.jl

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ using Serialization: serialize, deserialize
1212
weaklist::C.PyPtr = C_NULL
1313
end
1414

15+
@kwdef struct PyJuliaValueObjectFT
16+
ob_base::C.PyObjectFT = C.PyObjectFT()
17+
value::Int = 0
18+
weaklist::C.PyPtr = C_NULL
19+
end
20+
1521
const PyJuliaBase_Type = Ref(C.PyNULL)
1622

1723
# we store the actual julia values here
@@ -20,22 +26,37 @@ const PYJLVALUES = []
2026
# unused indices in PYJLVALUES
2127
const PYJLFREEVALUES = Int[]
2228

29+
# Choose the layout based on whether Python is free-threaded.
30+
pyjl_obj_type() = C.CTX.is_free_threaded ? PyJuliaValueObjectFT : PyJuliaValueObject
31+
32+
function pyjl_obj_ptr(o::C.PyPtr)
33+
if C.CTX.is_free_threaded
34+
return UnsafePtr{PyJuliaValueObjectFT}(o)
35+
else
36+
return UnsafePtr{PyJuliaValueObject}(o)
37+
end
38+
end
39+
2340
function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
24-
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
41+
alloc = C.PyType_GetSlot(t, C.Py_tp_alloc)
42+
alloc == C_NULL && return C.PyNULL
43+
o = ccall(alloc, C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
2544
o == C.PyNULL && return C.PyNULL
26-
UnsafePtr{PyJuliaValueObject}(o).weaklist[] = C.PyNULL
27-
UnsafePtr{PyJuliaValueObject}(o).value[] = 0
45+
p = pyjl_obj_ptr(o)
46+
p.weaklist[] = C.PyNULL
47+
p.value[] = 0
2848
return o
2949
end
3050

3151
function _pyjl_dealloc(o::C.PyPtr)
32-
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
52+
idx = pyjl_obj_ptr(o).value[]
3353
if idx != 0
3454
PYJLVALUES[idx] = nothing
3555
push!(PYJLFREEVALUES, idx)
3656
end
37-
UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
38-
ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o)
57+
pyjl_obj_ptr(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
58+
freeptr = C.PyType_GetSlot(C.Py_Type(o), C.Py_tp_free)
59+
freeptr == C_NULL || ccall(freeptr, Cvoid, (C.PyPtr,), o)
3960
nothing
4061
end
4162

@@ -314,12 +335,13 @@ function init_c()
314335

315336
# Create members for weakref support
316337
empty!(_pyjlbase_members)
338+
objT = pyjl_obj_type()
317339
push!(
318340
_pyjlbase_members,
319341
C.PyMemberDef(
320342
name = pointer(_pyjlbase_weaklistoffset_name),
321343
typ = C.Py_T_PYSSIZET,
322-
offset = fieldoffset(PyJuliaValueObject, 3),
344+
offset = fieldoffset(objT, 3),
323345
flags = C.Py_READONLY,
324346
),
325347
C.PyMemberDef(), # NULL terminator
@@ -341,7 +363,7 @@ function init_c()
341363
# Create PyType_Spec
342364
_pyjlbase_spec[] = C.PyType_Spec(
343365
name = pointer(_pyjlbase_name),
344-
basicsize = sizeof(PyJuliaValueObject),
366+
basicsize = sizeof(objT),
345367
flags = C.Py_TPFLAGS_BASETYPE | C.Py_TPFLAGS_HAVE_VERSION_TAG,
346368
slots = pointer(_pyjlbase_slots),
347369
)
@@ -358,13 +380,14 @@ function __init__()
358380
init_c()
359381
end
360382

361-
PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] == 0
383+
PyJuliaValue_IsNull(o) = Base.GC.@preserve o pyjl_obj_ptr(C.asptr(o)).value[] == 0
362384

363-
PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]]
385+
PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[pyjl_obj_ptr(C.asptr(o)).value[]]
364386

365387
PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
366388
o = C.asptr(_o)
367-
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
389+
p = pyjl_obj_ptr(o)
390+
idx = p.value[]
368391
if idx == 0
369392
if isempty(PYJLFREEVALUES)
370393
push!(PYJLVALUES, v)
@@ -373,7 +396,7 @@ PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
373396
idx = pop!(PYJLFREEVALUES)
374397
PYJLVALUES[idx] = v
375398
end
376-
UnsafePtr{PyJuliaValueObject}(o).value[] = idx
399+
p.value[] = idx
377400
else
378401
PYJLVALUES[idx] = v
379402
end

0 commit comments

Comments
 (0)