Skip to content

Commit 13fe8a2

Browse files
authored
Fix pyarrow import (#1316)
`neo4j.vector.Vector` uses `pyarrow.compute.count`, but the driver never imports the module `pyarrow.compute`. We haven't seen issues in tests because we either have all optional dependencies installed or none of them. Current versions of `pandas` (another optional dependency), when loaded, in turn load `pyarrow.compute`. However, this is not guaranteed to remain the case, and it doesn't work for users who have only pyarrow (but not pandas) installed. Furthermore, this PR silences pyarrow's type annotations for now, as they're not quite up to snuff yet.
1 parent 8f2b069 commit 13fe8a2

4 files changed

Lines changed: 16 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,13 @@ module = [
225225
]
226226
ignore_missing_imports = true
227227

228+
[[tool.mypy.overrides]]
229+
# https://github.com/apache/arrow/issues/50123
230+
module = [
231+
"pyarrow.*",
232+
]
233+
follow_imports = "skip"
234+
228235
[tool.ruff]
229236
line-length = 79
230237
extend-exclude = [

src/neo4j/_optional_deps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,16 @@
3232
import pandas as pd # type: ignore[no-redef]
3333

3434
pa: t.Any = None
35+
pa_compute: t.Any = None
3536

3637
with suppress(ImportError):
3738
import pyarrow as pa # type: ignore[no-redef]
39+
import pyarrow.compute as pa_compute # type: ignore[no-redef]
3840

3941

4042
__all__ = [
4143
"np",
4244
"pa",
45+
"pa_compute",
4346
"pd",
4447
]

src/neo4j/vector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@
3232
# This beautiful construct helps sphinx to properly resolve the type hints.
3333
import numpy as _np
3434
import pyarrow as _pa
35+
import pyarrow.compute as _pa_compute
3536
else:
3637
from ._optional_deps import (
3738
np as _np,
3839
pa as _pa,
40+
pa_compute as _pa_compute,
3941
)
4042

4143

@@ -795,7 +797,7 @@ def to_numpy(self) -> _np.ndarray: ...
795797
def from_pyarrow(cls, data: _pa.Array, /) -> _t.Self:
796798
width = data.type.byte_width
797799
assert cls.size == width
798-
if _pa.compute.count(data, mode="only_null").as_py():
800+
if _pa_compute.count(data, mode="only_null").as_py():
799801
raise ValueError("PyArrow array must not contain any null values.")
800802
_, buffer = data.buffers()
801803
buffer = buffer[

tests/unit/common/vector/test_vector.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from neo4j._optional_deps import (
2929
np,
3030
pa,
31+
pa_compute,
3132
)
3233
from neo4j.vector import (
3334
_swap_endian,
@@ -1044,7 +1045,7 @@ def test_to_pyarrow_random(
10441045
v = _vector_from_data(data_be, dtype, endian)
10451046
array = v.to_pyarrow()
10461047
assert array.type == pa_type
1047-
assert pa.compute.count(array, mode="only_null").as_py() == 0
1048+
assert pa_compute.count(array, mode="only_null").as_py() == 0
10481049
buffers = array.buffers()
10491050
assert len(buffers) == 2
10501051
assert buffers[0] is None
@@ -1076,7 +1077,7 @@ def test_to_pyarrow_special_values(
10761077
v = _vector_from_data(data_be, dtype, endian)
10771078
array = v.to_pyarrow()
10781079
assert array.type == pa_type
1079-
assert pa.compute.count(array, mode="only_null").as_py() == 0
1080+
assert pa_compute.count(array, mode="only_null").as_py() == 0
10801081
buffers = array.buffers()
10811082
assert len(buffers) == 2
10821083
assert buffers[0] is None

0 commit comments

Comments
 (0)