@@ -444,6 +444,29 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
444444
445445 return return_data
446446
447+ def ConvertBatchToJAX (self , batch : Any , device = None ) -> Any :
448+ """
449+ Convert a RTensor into a JAX array
450+
451+ Args:
452+ batch (RTensor): Batch returned from the DataLoader
453+
454+ Returns:
455+ jax.Array: converted batch
456+ """
457+ import jax
458+ import jax .numpy as jnp
459+
460+ split = self ._split_target_and_weights (jnp .asarray (self ._get_raw_array (batch )))
461+
462+ if isinstance (device , str ):
463+ device = jax .devices (device )[0 ]
464+
465+ return (
466+ tuple (jax .device_put (arr , device = device ) for arr in split )
467+ if isinstance (split , tuple )
468+ else jax .device_put (split , device = device )
469+ )
447470
448471 # Return a batch when available
449472 def GetTrainBatch (self ) -> Any :
@@ -745,6 +768,24 @@ def as_tensorflow(self) -> tf.data.Dataset:
745768 loader = FormattedLoader (self ._internal , self ._internal .ConvertBatchToTF , self ._is_training )
746769 return tf .data .Dataset .from_generator (lambda : loader , output_signature = batch_signature )
747770
771+ def as_jax (self , device : str | Any = None ) -> FormattedLoader :
772+ r"""
773+ \ingroup Py_ML
774+ Return an iterable that yields batches as JAX arrays.
775+
776+ Args:
777+ device: If given, the returned arrays are moved to the specified device.
778+ Can be a string (e.g. "cpu", "gpu", "tpu") or any of JAX's device objects.
779+ """
780+ try :
781+ import jax # noqa F401
782+ except ImportError :
783+ raise ImportError ("Failed to import jax needed for the ML dataloader" )
784+
785+ self ._ensure_created ()
786+ conversion_fn = lambda batch : self ._internal .ConvertBatchToJAX (batch , device ) # noqa: E731
787+ return FormattedLoader (self ._internal , conversion_fn , self ._is_training )
788+
748789 @property
749790 def columns (self ) -> list [str ]:
750791 r"""
0 commit comments