Commit 8debd72
Make xarray_jax.JaxArrayWrapper fully compliant with the xarray.namedarray._typing._array_function protocol from more recent versions of xarray (see https://github.com/pydata/xarray/blob/693f0b91b4381f5a672cb93ff8113abd1dc4957c/xarray/namedarray/_typing.py#L114).
This is required for more recent versions of xarray to recognise it as a duck-typed array. Otherwise xarray will try to convert it to a numpy array in some situations, in particular reductions which now go via NamedArray. This causes problems if done to a jax tracer.
The only change required to fulfill the protocol was to add `real` and `imag` properties to JaxArrayWrapper.
PiperOrigin-RevId: 595985462
Change-Id: I4c844b5fd7d7787f0e35cc210979d679ae8210441 parent 96de917 commit 8debd72
1 file changed
Lines changed: 15 additions & 4 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
404 | 404 | | |
405 | 405 | | |
406 | 406 | | |
407 | | - | |
408 | | - | |
409 | | - | |
410 | | - | |
| 407 | + | |
| 408 | + | |
| 409 | + | |
| 410 | + | |
| 411 | + | |
| 412 | + | |
411 | 413 | | |
| 414 | + | |
412 | 415 | | |
413 | 416 | | |
414 | 417 | | |
| |||
464 | 467 | | |
465 | 468 | | |
466 | 469 | | |
| 470 | + | |
| 471 | + | |
| 472 | + | |
| 473 | + | |
| 474 | + | |
| 475 | + | |
| 476 | + | |
| 477 | + | |
467 | 478 | | |
468 | 479 | | |
469 | 480 | | |
| |||
0 commit comments