Skip to content

Commit 47dd04d

Browse files
committed
[Python][ML] Add as_jax() output format to RDataLoader
1 parent 2eb1558 commit 47dd04d

1 file changed

Lines changed: 41 additions & 0 deletions

File tree

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,29 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
444444

445445
return return_data
446446

447+
def ConvertBatchToJAX(self, batch: Any, device=None) -> Any:
448+
"""
449+
Convert a RTensor into a JAX array
450+
451+
Args:
452+
batch (RTensor): Batch returned from the DataLoader
453+
454+
Returns:
455+
jax.Array: converted batch
456+
"""
457+
import jax
458+
import jax.numpy as jnp
459+
460+
split = self._split_target_and_weights(jnp.asarray(self._get_raw_array(batch)))
461+
462+
if isinstance(device, str):
463+
device = jax.devices(device)[0]
464+
465+
return (
466+
tuple(jax.device_put(arr, device=device) for arr in split)
467+
if isinstance(split, tuple)
468+
else jax.device_put(split, device=device)
469+
)
447470

448471
# Return a batch when available
449472
def GetTrainBatch(self) -> Any:
@@ -745,6 +768,24 @@ def as_tensorflow(self) -> tf.data.Dataset:
745768
loader = FormattedLoader(self._internal, self._internal.ConvertBatchToTF, self._is_training)
746769
return tf.data.Dataset.from_generator(lambda: loader, output_signature=batch_signature)
747770

771+
def as_jax(self, device: str | Any = None) -> FormattedLoader:
772+
r"""
773+
\ingroup Py_ML
774+
Return an iterable that yields batches as JAX arrays.
775+
776+
Args:
777+
device: If given, the returned arrays are moved to the specified device.
778+
Can be a string (e.g. "cpu", "gpu", "tpu") or any of JAX's device objects.
779+
"""
780+
try:
781+
import jax # noqa F401
782+
except ImportError:
783+
raise ImportError("Failed to import jax needed for the ML dataloader")
784+
785+
self._ensure_created()
786+
conversion_fn = lambda batch: self._internal.ConvertBatchToJAX(batch, device) # noqa: E731
787+
return FormattedLoader(self._internal, conversion_fn, self._is_training)
788+
748789
@property
749790
def columns(self) -> list[str]:
750791
r"""

0 commit comments

Comments
 (0)