Skip to content

Commit 8debd72

Browse files
mjwillsonvoctav
authored andcommitted
Make xarray_jax.JaxArrayWrapper fully compliant with the xarray.namedarray._typing._array_function protocol from more recent versions of xarray (see https://github.com/pydata/xarray/blob/693f0b91b4381f5a672cb93ff8113abd1dc4957c/xarray/namedarray/_typing.py#L114).
This is required for more recent versions of xarray to recognise it as a duck-typed array. Otherwise xarray will try to convert it to a numpy array in some situations, in particular reductions which now go via NamedArray. This causes problems if done to a jax tracer. The only change required to fulfill the protocol was to add `real` and `imag` properties to JaxArrayWrapper. PiperOrigin-RevId: 595985462 Change-Id: I4c844b5fd7d7787f0e35cc210979d679ae821044
1 parent 96de917 commit 8debd72

1 file changed

Lines changed: 15 additions & 4 deletions

File tree

graphcast/xarray_jax.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,14 @@ class JaxArrayWrapper(np.lib.mixins.NDArrayOperatorsMixin):
404404
"""Wraps a JAX array into a duck-typed array suitable for use with xarray.
405405
406406
This uses an older duck-typed array protocol based on __array_ufunc__ and
407-
__array_function__ which works with numpy and xarray. This is in the process
408-
of being superseded by the Python array API standard
409-
(https://data-apis.org/array-api/latest/index.html), but JAX and xarray
410-
haven't implemented it yet. Once they have, we should be able to get rid of
407+
__array_function__ which works with numpy and xarray. (In newer versions
408+
of xarray it implements xarray.namedarray._typing._array_function.)
409+
410+
This is in the process of being superseded by the Python array API standard
411+
(https://data-apis.org/array-api/latest/index.html), but JAX hasn't
412+
implemented it yet. Once they have, we should be able to get rid of
411413
this wrapper and use JAX arrays directly with xarray.
414+
412415
"""
413416

414417
def __init__(self, jax_array):
@@ -464,6 +467,14 @@ def ndim(self):
464467
def size(self):
465468
return self.jax_array.size
466469

470+
@property
471+
def real(self):
472+
return self.jax_array.real
473+
474+
@property
475+
def imag(self):
476+
return self.jax_array.imag
477+
467478
# Array methods not covered by NDArrayOperatorsMixin:
468479

469480
# Allows conversion to numpy array using np.asarray etc. Warning: doing this

0 commit comments

Comments
 (0)