Skip to content

Commit 29bb919

Browse files
[SPARK-57679][PYTHON] Fix numpy type checking
### What changes were proposed in this pull request? Add a branch to detect `NDArray` for `numpy >= 2.5.0`. ### Why are the changes needed? Our detection for `NDArray` is a bit fragile. `numpy` changed the type alias in `2.5.0` so we need another way to detect it. CI is failing - https://github.com/apache/spark/actions/runs/28064134730/job/83087583742 ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Local test passed ### Was this patch authored or co-authored using generative AI tooling? Yes, Claude Code (Opus 4.8 high) Closes #56757 from gaogaotiantian/fix-numpy-type-checking. Authored-by: Tian Gao <gaogaotiantian@hotmail.com> Signed-off-by: Tian Gao <gaogaotiantian@hotmail.com> (cherry picked from commit 5f7570c) Signed-off-by: Tian Gao <gaogaotiantian@hotmail.com>
1 parent 29caede commit 29bb919

1 file changed

Lines changed: 11 additions & 1 deletion

File tree

python/pyspark/pandas/typedef/typehints.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,8 +155,18 @@ def as_spark_type(
155155
and hasattr(tpe, "__args__")
156156
and len(tpe.__args__) > 1
157157
):
158-
# numpy.typing.NDArray
158+
# numpy.typing.NDArray for numpy < 2.5
159159
return types.ArrayType(as_spark_type(tpe.__args__[1].__args__[0], raise_error=raise_error))
160+
elif (
161+
hasattr(tpe, "__origin__")
162+
and hasattr(tpe.__origin__, "__value__")
163+
and getattr(tpe.__origin__.__value__, "__origin__", None) is np.ndarray
164+
and hasattr(tpe, "__args__")
165+
and len(tpe.__args__) > 0
166+
):
167+
# numpy.typing.NDArray for numpy >= 2.5: a PEP 695 type alias whose __value__
168+
# resolves to np.ndarray[shape, dtype[scalar]], with the scalar at __args__[0]
169+
return types.ArrayType(as_spark_type(tpe.__args__[0], raise_error=raise_error))
160170

161171
if isinstance(tpe, np.dtype) and tpe == np.dtype("object"):
162172
pass

0 commit comments

Comments
 (0)