
Commit 5a803f6

[SPARK-57679][PYTHON] Fix numpy type checking
### What changes were proposed in this pull request?
Add a branch to detect `NDArray` for `numpy >= 2.5.0`.

### Why are the changes needed?
Our detection for `NDArray` is a bit fragile. `numpy` changed the type alias in `2.5.0`, so we need another way to detect it. CI is failing: https://github.com/apache/spark/actions/runs/28064134730/job/83087583742

### Does this PR introduce any user-facing change?
No.

### How was this patch tested?
Local test passed.

### Was this patch authored or co-authored using generative AI tooling?
Yes, Claude Code (Opus 4.8 high)

Closes #56757 from gaogaotiantian/fix-numpy-type-checking.

Authored-by: Tian Gao <gaogaotiantian@hotmail.com>
Signed-off-by: Tian Gao <gaogaotiantian@hotmail.com>
(cherry picked from commit 5f7570c)
Signed-off-by: Tian Gao <gaogaotiantian@hotmail.com>
1 parent 953e022 commit 5a803f6

1 file changed

Lines changed: 11 additions & 1 deletion


python/pyspark/pandas/typedef/typehints.py

@@ -158,10 +158,20 @@ def as_spark_type(
             and hasattr(tpe, "__args__")
             and len(tpe.__args__) > 1
         ):
-            # numpy.typing.NDArray
+            # numpy.typing.NDArray for numpy < 2.5
             return types.ArrayType(
                 as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error)
             )
+        elif (
+            hasattr(tpe, "__origin__")
+            and hasattr(tpe.__origin__, "__value__")
+            and getattr(tpe.__origin__.__value__, "__origin__", None) is np.ndarray
+            and hasattr(tpe, "__args__")
+            and len(tpe.__args__) > 0
+        ):
+            # numpy.typing.NDArray for numpy >= 2.5: a PEP 695 type alias whose __value__
+            # resolves to np.ndarray[shape, dtype[scalar]], with the scalar at __args__[0]
+            return types.ArrayType(as_spark_type(tpe.__args__[0], raise_error=raise_error))
 
     if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
         pass
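
The two branches in the hunk above can be illustrated without installing a specific numpy version. The following is a minimal sketch, not the code in `typehints.py`: `ndarray_scalar_type`, `generic`, `FakeNDArray`, and `FakeDType` are hypothetical names, and the stand-in objects only imitate the attribute layout described in the diff comments (`__origin__`, `__args__`, and, for the PEP 695 alias, `__value__`). It assumes that layout is accurate for the numpy versions named in the commit message.

```python
from types import SimpleNamespace
from typing import Any


class FakeNDArray:
    """Hypothetical stand-in for np.ndarray."""


class FakeDType:
    """Hypothetical stand-in for np.dtype."""


def generic(origin, *args):
    """Build an object that imitates a subscripted generic (origin[args])."""
    return SimpleNamespace(__origin__=origin, __args__=args)


def ndarray_scalar_type(tpe, ndarray_cls):
    """Return the scalar type of an NDArray-like hint, or None if it is not one."""
    origin = getattr(tpe, "__origin__", None)
    args = getattr(tpe, "__args__", ())

    # Older layout: the hint is ndarray[shape, dtype[scalar]] directly,
    # so the scalar sits inside the second argument.
    if origin is ndarray_cls and len(args) > 1:
        return args[1].__args__[0]

    # Newer layout: the hint's origin is a type alias object whose __value__
    # is ndarray[shape, dtype[T]]; the scalar is the alias's own first argument.
    alias_value = getattr(origin, "__value__", None)
    if getattr(alias_value, "__origin__", None) is ndarray_cls and len(args) > 0:
        return args[0]

    return None


# Imitates NDArray[int] under the older layout.
old_style = generic(FakeNDArray, Any, generic(FakeDType, int))

# Imitates NDArray[float] under the newer layout: an alias object with __value__.
alias = SimpleNamespace(
    __value__=generic(FakeNDArray, Any, generic(FakeDType, "ScalarT"))
)
new_style = generic(alias, float)

print(ndarray_scalar_type(old_style, FakeNDArray))
print(ndarray_scalar_type(new_style, FakeNDArray))
print(ndarray_scalar_type(generic(list, int), FakeNDArray))
```

As in the patch, the older check runs first, so a hint whose `__origin__` is the array class itself never reaches the alias branch. Against a real numpy install the same function could be called with `np.ndarray` as `ndarray_cls`, though that is not exercised here.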
