Skip to content

Commit 946f6a8

Browse files
committed
Move numpy dtype mapping to Julia
1 parent fa80da6 commit 946f6a8

File tree

4 files changed

+98
-64
lines changed

4 files changed

+98
-64
lines changed

docs/src/juliacall-reference.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,10 @@ from juliacall import Main as jl
200200
# equivalent to Vector{Int}() in Julia
201201
jl.Vector[jl.Int]()
202202
```
203+
204+
If NumPy is available, primitive types expose a `__numpy_dtype__` property that returns the
205+
corresponding `numpy.dtype` (e.g. `jl.Int64.__numpy_dtype__`). Unsupported types raise
206+
`AttributeError`.
203207
`````
204208

205209
`````@customdoc

pytest/test_all.py

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -25,33 +25,6 @@ def test_convert():
2525
assert jl.isa(y, t)
2626

2727

28-
def test_typevalue_numpy_dtype():
29-
import numpy as np
30-
from juliacall import Base as jl
31-
32-
assert jl.Bool.__numpy_dtype__ == np.dtype(np.bool_)
33-
assert jl.Int8.__numpy_dtype__ == np.dtype(np.int8)
34-
assert jl.Int16.__numpy_dtype__ == np.dtype(np.int16)
35-
assert jl.Int32.__numpy_dtype__ == np.dtype(np.int32)
36-
assert jl.Int64.__numpy_dtype__ == np.dtype(np.int64)
37-
assert jl.Int.__numpy_dtype__ == np.dtype(np.int_)
38-
assert jl.UInt8.__numpy_dtype__ == np.dtype(np.uint8)
39-
assert jl.UInt16.__numpy_dtype__ == np.dtype(np.uint16)
40-
assert jl.UInt32.__numpy_dtype__ == np.dtype(np.uint32)
41-
assert jl.UInt64.__numpy_dtype__ == np.dtype(np.uint64)
42-
assert jl.UInt.__numpy_dtype__ == np.dtype(np.uintp)
43-
assert jl.Float16.__numpy_dtype__ == np.dtype(np.float16)
44-
assert jl.Float32.__numpy_dtype__ == np.dtype(np.float32)
45-
assert jl.Float64.__numpy_dtype__ == np.dtype(np.float64)
46-
assert jl.ComplexF32.__numpy_dtype__ == np.dtype(np.complex64)
47-
assert jl.ComplexF64.__numpy_dtype__ == np.dtype(np.complex128)
48-
assert jl.Ptr[jl.Cvoid].__numpy_dtype__ == np.dtype("P")
49-
with pytest.raises(AttributeError):
50-
_ = jl.ComplexF16.__numpy_dtype__
51-
with pytest.raises(AttributeError):
52-
_ = jl.String.__numpy_dtype__
53-
54-
5528
def test_interactive():
5629
import juliacall
5730

src/JlWrap/type.jl

Lines changed: 50 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,55 @@ function pyjltype_getitem(self::Type, k_)
1111
end
1212
end
1313

14+
function pyjltype_numpy_dtype(self::Type)
15+
np = pyimport("numpy")
16+
if self === Bool
17+
return np.dtype(np.bool_)
18+
elseif self === Int8
19+
return np.dtype(np.int8)
20+
elseif self === Int16
21+
return np.dtype(np.int16)
22+
elseif self === Int32
23+
return np.dtype(np.int32)
24+
elseif self === Int64
25+
return np.dtype(np.int64)
26+
elseif self === UInt8
27+
return np.dtype(np.uint8)
28+
elseif self === UInt16
29+
return np.dtype(np.uint16)
30+
elseif self === UInt32
31+
return np.dtype(np.uint32)
32+
elseif self === UInt64
33+
return np.dtype(np.uint64)
34+
elseif self === Float16
35+
return np.dtype(np.float16)
36+
elseif self === Float32
37+
return np.dtype(np.float32)
38+
elseif self === Float64
39+
return np.dtype(np.float64)
40+
elseif self === ComplexF32
41+
return np.dtype(np.complex64)
42+
elseif self === ComplexF64
43+
return np.dtype(np.complex128)
44+
elseif self === Ptr{Cvoid}
45+
return np.dtype("P")
46+
end
47+
@static if Int !== Int64
48+
if self === Int
49+
return np.dtype(np.int_)
50+
end
51+
end
52+
@static if UInt !== UInt64
53+
if self === UInt
54+
return np.dtype(np.uintp)
55+
end
56+
end
57+
errset(pybuiltins.AttributeError, "__numpy_dtype__")
58+
return PyNULL
59+
end
60+
61+
pyjl_handle_error_type(::typeof(pyjltype_numpy_dtype), x, exc) = pybuiltins.AttributeError
62+
1463
function init_type()
1564
jl = pyjuliacallmodule
1665
pybuiltins.exec(
@@ -27,42 +76,7 @@ class TypeValue(AnyValue):
2776
raise TypeError("not supported")
2877
@property
2978
def __numpy_dtype__(self):
30-
import numpy
31-
if self == Base.Bool:
32-
return numpy.dtype(numpy.bool_)
33-
if self == Base.Int8:
34-
return numpy.dtype(numpy.int8)
35-
if self == Base.Int16:
36-
return numpy.dtype(numpy.int16)
37-
if self == Base.Int32:
38-
return numpy.dtype(numpy.int32)
39-
if self == Base.Int64:
40-
return numpy.dtype(numpy.int64)
41-
if self == Base.Int:
42-
return numpy.dtype(numpy.int_)
43-
if self == Base.UInt8:
44-
return numpy.dtype(numpy.uint8)
45-
if self == Base.UInt16:
46-
return numpy.dtype(numpy.uint16)
47-
if self == Base.UInt32:
48-
return numpy.dtype(numpy.uint32)
49-
if self == Base.UInt64:
50-
return numpy.dtype(numpy.uint64)
51-
if self == Base.UInt:
52-
return numpy.dtype(numpy.uintp)
53-
if self == Base.Float16:
54-
return numpy.dtype(numpy.float16)
55-
if self == Base.Float32:
56-
return numpy.dtype(numpy.float32)
57-
if self == Base.Float64:
58-
return numpy.dtype(numpy.float64)
59-
if self == Base.ComplexF32:
60-
return numpy.dtype(numpy.complex64)
61-
if self == Base.ComplexF64:
62-
return numpy.dtype(numpy.complex128)
63-
if self == Base.Ptr[Base.Cvoid]:
64-
return numpy.dtype("P")
65-
raise AttributeError("__numpy_dtype__")
79+
return self._jl_callmethod($(pyjl_methodnum(pyjltype_numpy_dtype)))
6680
""",
6781
@__FILE__(),
6882
"exec",

test/JlWrap.jl

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,13 +472,56 @@ end
472472
end
473473
end
474474

475-
@testitem "type" begin
475+
@testitem "type" setup=[Setup] begin
476476
@testset "type" begin
477477
@test pyis(pytype(pyjl(Int)), PythonCall.pyjltypetype)
478478
end
479479
@testset "bool" begin
480480
@test pytruth(pyjl(Int))
481481
end
482+
@testset "numpy dtype" begin
483+
if Setup.devdeps
484+
np = pyimport("numpy")
485+
@test pyeq(Bool, pygetattr(pyjl(Bool), "__numpy_dtype__"), np.dtype(np.bool_))
486+
@test pyeq(Bool, pygetattr(pyjl(Int8), "__numpy_dtype__"), np.dtype(np.int8))
487+
@test pyeq(Bool, pygetattr(pyjl(Int16), "__numpy_dtype__"), np.dtype(np.int16))
488+
@test pyeq(Bool, pygetattr(pyjl(Int32), "__numpy_dtype__"), np.dtype(np.int32))
489+
@test pyeq(Bool, pygetattr(pyjl(Int64), "__numpy_dtype__"), np.dtype(np.int64))
490+
@test pyeq(Bool, pygetattr(pyjl(Int), "__numpy_dtype__"), np.dtype(np.int_))
491+
@test pyeq(Bool, pygetattr(pyjl(UInt8), "__numpy_dtype__"), np.dtype(np.uint8))
492+
@test pyeq(Bool, pygetattr(pyjl(UInt16), "__numpy_dtype__"), np.dtype(np.uint16))
493+
@test pyeq(Bool, pygetattr(pyjl(UInt32), "__numpy_dtype__"), np.dtype(np.uint32))
494+
@test pyeq(Bool, pygetattr(pyjl(UInt64), "__numpy_dtype__"), np.dtype(np.uint64))
495+
@test pyeq(Bool, pygetattr(pyjl(UInt), "__numpy_dtype__"), np.dtype(np.uintp))
496+
@test pyeq(Bool, pygetattr(pyjl(Float16), "__numpy_dtype__"), np.dtype(np.float16))
497+
@test pyeq(Bool, pygetattr(pyjl(Float32), "__numpy_dtype__"), np.dtype(np.float32))
498+
@test pyeq(Bool, pygetattr(pyjl(Float64), "__numpy_dtype__"), np.dtype(np.float64))
499+
@test pyeq(Bool, pygetattr(pyjl(ComplexF32), "__numpy_dtype__"), np.dtype(np.complex64))
500+
@test pyeq(Bool, pygetattr(pyjl(ComplexF64), "__numpy_dtype__"), np.dtype(np.complex128))
501+
@test pyeq(Bool, pygetattr(pyjl(Ptr{Cvoid}), "__numpy_dtype__"), np.dtype("P"))
502+
@test pyeq(Bool, np.dtype(pyjl(Int64)), np.dtype(np.int64))
503+
504+
err = try
505+
pygetattr(pyjl(ComplexF16), "__numpy_dtype__")
506+
nothing
507+
catch err
508+
err
509+
end
510+
@test err isa PythonCall.PyException
511+
@test pyis(err._t, pybuiltins.AttributeError)
512+
513+
err = try
514+
pygetattr(pyjl(String), "__numpy_dtype__")
515+
nothing
516+
catch err
517+
err
518+
end
519+
@test err isa PythonCall.PyException
520+
@test pyis(err._t, pybuiltins.AttributeError)
521+
else
522+
@test_skip Setup.devdeps
523+
end
524+
end
482525
end
483526

484527
@testitem "vector" begin

0 commit comments

Comments
 (0)