Skip to content

Commit 2e0bea9

Browse files
committed
fix: route PyDict REPL completion through Python thread
1 parent beec1ec commit 2e0bea9

6 files changed

Lines changed: 47 additions & 8 deletions

File tree

src/Core/Core.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__))
1111
using ..PythonCall
1212
using ..C
1313
using ..GC: GC
14+
using ..GIL
1415
using ..Utils
1516

1617
using Base: @propagate_inbounds, @kwdef

src/Core/Py.jl

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,27 @@ Py(x::Date) = pydate(x)
125125
Py(x::Time) = pytime(x)
126126
Py(x::DateTime) = pydatetime(x)
127127

128+
function _py_with_gil_or_on_main_thread(f)
129+
if C.PyGILState_Check() == 1
130+
f()
131+
elseif C.on_main_thread(Threads.threadid)::Int == Threads.threadid()
132+
GIL.lock(f)
133+
else
134+
C.on_main_thread(f)
135+
end
136+
end
137+
128138
Base.string(x::Py) = pyisnull(x) ? "<py NULL>" : pystr(String, x)
129139
Base.print(io::IO, x::Py) = print(io, string(x))
130140

131141
function Base.show(io::IO, x::Py)
142+
_py_with_gil_or_on_main_thread() do
143+
_show(io, x)
144+
nothing
145+
end
146+
end
147+
148+
function _show(io::IO, x::Py)
132149
if get(io, :typeinfo, Any) == Py
133150
if pyisnull(x)
134151
print(io, "NULL")
@@ -292,13 +309,9 @@ function _propertynames(x::Py, private::Bool)
292309
end
293310

294311
function Base.propertynames(x::Py, private::Bool = false)
295-
if C.PyGILState_Check() == 1
312+
_py_with_gil_or_on_main_thread() do
296313
_propertynames(x, private)
297-
else
298-
C.on_main_thread() do
299-
_propertynames(x, private)
300-
end::Vector{Symbol}
301-
end
314+
end::Vector{Symbol}
302315
end
303316

304317
Base.Bool(x::Py) = pytruth(x)

src/Wrap/PyDict.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,19 @@ function Base.iterate(x::PyDict{K,V}, it::Py = pyiter(x)) where {K,V}
2424
return (k => v, it)
2525
end
2626

27-
function Base.iterate(x::Base.KeySet{K,PyDict{K,V}}, it::Py = pyiter(x.dict)) where {K,V}
27+
function Base.iterate(x::Base.KeySet{K,<:PyDict{K}})::Union{Nothing,Tuple{K,Py}} where {K}
28+
_py_with_gil_or_on_main_thread() do
29+
_iterate(x, pyiter(x.dict))
30+
end
31+
end
32+
33+
function Base.iterate(x::Base.KeySet{K,<:PyDict{K}}, it::Py)::Union{Nothing,Tuple{K,Py}} where {K}
34+
_py_with_gil_or_on_main_thread() do
35+
_iterate(x, it)
36+
end
37+
end
38+
39+
function _iterate(x::Base.KeySet{K,<:PyDict{K}}, it::Py) where {K}
2840
k_ = unsafe_pynext(it)
2941
pyisnull(k_) && return nothing
3042
k = pyconvert(K, k_)

src/Wrap/Wrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ using Base: @propagate_inbounds
2020
using Tables: Tables
2121
using UnsafePointers: UnsafePtr
2222

23-
import ..Core: Py, ispy
23+
import ..Core: Py, ispy, _py_with_gil_or_on_main_thread
2424

2525
include("PyIterable.jl")
2626
include("PyDict.jl")

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
55
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
66
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
77
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
8+
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
89
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
910
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
1011
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

test/Wrap.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,18 @@ end
122122
@testset "iterate keys" begin
123123
@test collect(keys(z)) == ["foo"]
124124
end
125+
@testset "complete keys without GIL" begin
126+
using REPL
127+
completion_count = PythonCall.GIL.@unlock begin
128+
task = @async begin
129+
completions, _, _ = REPL.REPLCompletions.completions("y[", 2, @__MODULE__)
130+
length(completions)
131+
end
132+
wait(task)
133+
fetch(task)
134+
end
135+
@test completion_count == 1
136+
end
125137
@testset "getindex" begin
126138
@test z["foo"] === 12
127139
@test_throws KeyError z["bar"]

0 commit comments

Comments
 (0)