@@ -382,21 +382,21 @@ def _split_target_and_weights(
382382 return data
383383
384384 def ConvertBatchToNumpy (self , batch ) -> np .ndarray :
385- """Convert a RTensor into a NumPy array
385+ """Convert the batch into a NumPy array
386386
387387 Args:
388- batch (RTensor) : Batch returned from the DataLoader
388+ batch: Batch returned from the DataLoader
389389
390390 Returns:
391391 np.ndarray: converted batch
392392 """
393393 return self ._split_target_and_weights (self ._get_raw_array (batch ))
394394
395395 def ConvertBatchToPyTorch (self , batch : Any , device = None ) -> torch .Tensor :
396- """Convert a RTensor into a PyTorch tensor
396+ """Convert the batch into a PyTorch tensor
397397
398398 Args:
399- batch (RTensor) : Batch returned from the DataLoader
399+ batch: Batch returned from the DataLoader
400400
401401 Returns:
402402 torch.Tensor: converted batch
@@ -412,10 +412,10 @@ def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
412412
413413 def ConvertBatchToTF (self , batch : Any ) -> Any :
414414 """
415- Convert a RTensor into a TensorFlow tensor
415+ Convert the batch into a TensorFlow tensor
416416
417417 Args:
418- batch (RTensor) : Batch returned from the DataLoader
418+ batch: Batch returned from the DataLoader
419419
420420 Returns:
421421 tensorflow.Tensor: converted batch
@@ -444,6 +444,29 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
444444
445445 return return_data
446446
447+ def ConvertBatchToJAX (self , batch : Any , device = None ) -> Any :
448+ """
449+ Convert the batch into a JAX array
450+
451+ Args:
452+ batch: Batch returned from the DataLoader
453+
454+ Returns:
455+ jax.Array: converted batch
456+ """
457+ import jax
458+ import jax .numpy as jnp
459+
460+ split = self ._split_target_and_weights (jnp .asarray (self ._get_raw_array (batch )))
461+
462+ if isinstance (device , str ):
463+ device = jax .devices (device )[0 ]
464+
465+ return (
466+ tuple (jax .device_put (arr , device = device ) for arr in split )
467+ if isinstance (split , tuple )
468+ else jax .device_put (split , device = device )
469+ )
447470
448471 # Return a batch when available
449472 def GetTrainBatch (self ) -> Any :
@@ -745,6 +768,24 @@ def as_tensorflow(self) -> tf.data.Dataset:
745768 loader = FormattedLoader (self ._internal , self ._internal .ConvertBatchToTF , self ._is_training )
746769 return tf .data .Dataset .from_generator (lambda : loader , output_signature = batch_signature )
747770
771+ def as_jax (self , device : str | Any = None ) -> FormattedLoader :
772+ r"""
773+ \ingroup Py_ML
774+ Return an iterable that yields batches as JAX arrays.
775+
776+ Args:
777+ device: If given, the returned arrays are moved to the specified device.
778+ Can be a string (e.g. "cpu", "gpu", "tpu") or any of JAX's device objects.
779+ """
780+ try :
781+ import jax # noqa F401
782+ except ImportError :
783+ raise ImportError ("Failed to import jax needed for the ML dataloader" )
784+
785+ self ._ensure_created ()
786+ conversion_fn = lambda batch : self ._internal .ConvertBatchToJAX (batch , device ) # noqa: E731
787+ return FormattedLoader (self ._internal , conversion_fn , self ._is_training )
788+
748789 @property
749790 def columns (self ) -> list [str ]:
750791 r"""
0 commit comments