From 2eb1558704be2394aa59930b3a022c0922e590c0 Mon Sep 17 00:00:00 2001 From: Silia Taider Date: Tue, 9 Jun 2026 16:32:00 +0200 Subject: [PATCH 1/3] [Python][ML] Refactor batch conversion into helpers --- .../ROOT/_pythonization/_ml_dataloader.py | 89 ++++++++----------- 1 file changed, 39 insertions(+), 50 deletions(-) diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py index 3ab53859622e4..db00746864ffd 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py @@ -345,15 +345,7 @@ def GetSample(self): np.zeros((self.batch_size)).reshape(-1, 1), ) - def ConvertBatchToNumpy(self, batch) -> np.ndarray: - """Convert a RTensor into a NumPy array - - Args: - batch (RTensor): Batch returned from the DataLoader - - Returns: - np.ndarray: converted batch - """ + def _get_raw_array(self, batch) -> np.ndarray: try: import numpy as np except ImportError: @@ -361,19 +353,21 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray: data = batch.GetData() batch_size, num_columns = tuple(batch.GetShape()) - data.reshape((batch_size * num_columns,)) - return_data = np.asarray(data).reshape(batch_size, num_columns) + return np.asarray(data).reshape(batch_size, num_columns) + def _split_target_and_weights( + self, data: np.ndarray + ) -> np.ndarray | Tuple[np.ndarray, np.ndarray] | Tuple[np.ndarray, np.ndarray, np.ndarray]: # Splice target column from the data if target is given if self.target_given: - train_data = return_data[:, self.train_indices] - target_data = return_data[:, self.target_indices] + train_data = data[:, self.train_indices] + target_data = data[:, self.target_indices] # Splice weight column from the data if weight is given if self.weights_given: - weights_data = return_data[:, self.weights_index] + weights_data = data[:, self.weights_index] if len(self.target_indices) == 1: return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1) @@ -385,7 +379,18 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray: return train_data, target_data - return return_data + return data + + def ConvertBatchToNumpy(self, batch) -> np.ndarray: + """Convert a RTensor into a NumPy array + + Args: + batch (RTensor): Batch returned from the DataLoader + + Returns: + np.ndarray: converted batch + """ + return self._split_target_and_weights(self._get_raw_array(batch)) def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: """Convert a RTensor into a PyTorch tensor @@ -396,36 +401,14 @@ def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: Returns: torch.Tensor: converted batch """ - import numpy as np import torch - data = batch.GetData() - batch_size, num_columns = tuple(batch.GetShape()) - - data.reshape((batch_size * num_columns,)) - - return_data = torch.as_tensor(np.asarray(data), device=device).reshape(batch_size, num_columns) - - # Splice target column from the data if target is given - if self.target_given: - train_data = return_data[:, self.train_indices] - target_data = return_data[:, self.target_indices] - - # Splice weight column from the data if weight is given - if self.weights_given: - weights_data = return_data[:, self.weights_index] - - if len(self.target_indices) == 1: - return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1) - - return train_data, target_data, weights_data.reshape(-1, 1) - - if len(self.target_indices) == 1: - return train_data, target_data.reshape(-1, 1) - - return train_data, target_data - - return return_data + split = self._split_target_and_weights(self._get_raw_array(batch)) + return ( + tuple(torch.as_tensor(arr, device=device) for arr in split) + if isinstance(split, tuple) + else torch.as_tensor(split, device=device) + ) def ConvertBatchToTF(self, batch: Any) -> Any: """ @@ -439,12 +422,9 @@ def ConvertBatchToTF(self, batch: Any) -> Any: """ import tensorflow as tf - data = batch.GetData() - batch_size, num_columns = tuple(batch.GetShape()) - - data.reshape((batch_size * num_columns,)) - - return_data = tf.constant(data, shape=(batch_size, num_columns)) + arr = self._get_raw_array(batch) + batch_size = arr.shape[0] + return_data = tf.constant(arr) if batch_size != self.batch_size: return_data = tf.pad(return_data, tf.constant([[0, self.batch_size - batch_size], [0, 0]])) @@ -464,6 +444,7 @@ def ConvertBatchToTF(self, batch: Any) -> Any: return return_data + # Return a batch when available def GetTrainBatch(self) -> Any: """Return the next training batch of data from the given RDataFrame @@ -717,6 +698,11 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader: Args: device: If given, the returned tensors are moved to the specified device. """ + try: + import torch # noqa F401 + except ImportError: + raise ImportError("Failed to import torch needed for the ML dataloader") + self._ensure_created() conversion_fn = lambda batch: self._internal.ConvertBatchToPyTorch(batch, device) # noqa: E731 return FormattedLoader(self._internal, conversion_fn, self._is_training) @@ -726,7 +712,10 @@ def as_tensorflow(self) -> tf.data.Dataset: \ingroup Py_ML Return a tf.data.Dataset over batches as TensorFlow tensors. """ - import tensorflow as tf + try: + import tensorflow as tf + except ImportError: + raise ImportError("Failed to import tensorflow needed for the ML dataloader") self._ensure_created() From e00bd975883af7a1241ef2b022dfbd0c8a3926a4 Mon Sep 17 00:00:00 2001 From: Silia Taider Date: Tue, 9 Jun 2026 17:19:42 +0200 Subject: [PATCH 2/3] [Python][ML] Add as_jax() output format to RDataLoader --- .../ROOT/_pythonization/_ml_dataloader.py | 53 ++++++++++++++++--- 1 file changed, 47 insertions(+), 6 deletions(-) diff --git a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py index db00746864ffd..c54339ceef47f 100644 --- a/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py +++ b/bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py @@ -382,10 +382,10 @@ def _split_target_and_weights( return data def ConvertBatchToNumpy(self, batch) -> np.ndarray: - """Convert a RTensor into a NumPy array + """Convert the batch into a NumPy array Args: - batch (RTensor): Batch returned from the DataLoader + batch: Batch returned from the DataLoader Returns: np.ndarray: converted batch @@ -393,10 +393,10 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray: return self._split_target_and_weights(self._get_raw_array(batch)) def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: - """Convert a RTensor into a PyTorch tensor + """Convert the batch into a PyTorch tensor Args: - batch (RTensor): Batch returned from the DataLoader + batch: Batch returned from the DataLoader Returns: torch.Tensor: converted batch @@ -412,10 +412,10 @@ def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor: def ConvertBatchToTF(self, batch: Any) -> Any: """ - Convert a RTensor into a TensorFlow tensor + Convert the batch into a TensorFlow tensor Args: - batch (RTensor): Batch returned from the DataLoader + batch: Batch returned from the DataLoader Returns: tensorflow.Tensor: converted batch @@ -444,6 +444,29 @@ def ConvertBatchToTF(self, batch: Any) -> Any: return return_data + def ConvertBatchToJAX(self, batch: Any, device=None) -> Any: + """ + Convert the batch into a JAX array + + Args: + batch: Batch returned from the DataLoader + + Returns: + jax.Array: converted batch + """ + import jax + import jax.numpy as jnp + + split = self._split_target_and_weights(jnp.asarray(self._get_raw_array(batch))) + + if isinstance(device, str): + device = jax.devices(device)[0] + + return ( + tuple(jax.device_put(arr, device=device) for arr in split) + if isinstance(split, tuple) + else jax.device_put(split, device=device) + ) # Return a batch when available def GetTrainBatch(self) -> Any: @@ -745,6 +768,24 @@ def as_tensorflow(self) -> tf.data.Dataset: loader = FormattedLoader(self._internal, self._internal.ConvertBatchToTF, self._is_training) return tf.data.Dataset.from_generator(lambda: loader, output_signature=batch_signature) + def as_jax(self, device: str | Any = None) -> FormattedLoader: + r""" + \ingroup Py_ML + Return an iterable that yields batches as JAX arrays. + + Args: + device: If given, the returned arrays are moved to the specified device. + Can be a string (e.g. "cpu", "gpu", "tpu") or any of JAX's device objects. + """ + try: + import jax # noqa F401 + except ImportError: + raise ImportError("Failed to import jax needed for the ML dataloader") + + self._ensure_created() + conversion_fn = lambda batch: self._internal.ConvertBatchToJAX(batch, device) # noqa: E731 + return FormattedLoader(self._internal, conversion_fn, self._is_training) + @property def columns(self) -> list[str]: r""" From 781a64454663275b6901b8375053db1df110eef1 Mon Sep 17 00:00:00 2001 From: Silia Taider Date: Tue, 9 Jun 2026 17:49:24 +0200 Subject: [PATCH 3/3] [Python][ML] Add test for the jax output format --- .../pythonizations/test/ml_dataloader.py | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/bindings/pyroot/pythonizations/test/ml_dataloader.py b/bindings/pyroot/pythonizations/test/ml_dataloader.py index c03143dc1c9ab..93bbb58354a61 100644 --- a/bindings/pyroot/pythonizations/test/ml_dataloader.py +++ b/bindings/pyroot/pythonizations/test/ml_dataloader.py @@ -1116,6 +1116,88 @@ def test16_vector_padding(self): self.teardown_file(self.file_name3) raise + def test17_JAX(self): + file_name = "multiple_target_columns.root" + + ROOT.RDataFrame(10).Define("b1", "(Short_t) rdfentry_").Define("b2", "(UShort_t) b1 * b1").Define( + "b3", "(double) rdfentry_ * 10" + ).Define("b4", "(double) b3 * 10").Snapshot("myTree", file_name) + + try: + df = ROOT.RDataFrame("myTree", file_name) + + dl = ROOT.Experimental.ML.RDataLoader( + df, + batch_size=3, + batches_in_memory=2, + target=["b2", "b4"], + weights="b3", + shuffle=False, + drop_remainder=False, + ) + + gen_train, gen_validation = dl.train_test_split(0.4) + + results_x_train = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0] + results_x_val = [6.0, 7.0, 8.0, 9.0] + results_y_train = [0.0, 0.0, 1.0, 100.0, 4.0, 200.0, 9.0, 300.0, 16.0, 400.0, 25.0, 500.0] + results_y_val = [36.0, 600.0, 49.0, 700.0, 64.0, 800.0, 81.0, 900.0] + results_z_train = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0] + results_z_val = [60.0, 70.0, 80.0, 90.0] + + collected_x_train = [] + collected_x_val = [] + collected_y_train = [] + collected_y_val = [] + collected_z_train = [] + collected_z_val = [] + + iter_train = iter(gen_train.as_jax(device="cpu")) + iter_val = iter(gen_validation.as_jax()) + + for _ in range(self.n_train_batch): + x, y, z = next(iter_train) + self.assertTrue(x.shape == (3, 1)) + self.assertTrue(y.shape == (3, 2)) + self.assertTrue(z.shape == (3, 1)) + collected_x_train.append(x.tolist()) + collected_y_train.append(y.tolist()) + collected_z_train.append(z.tolist()) + + for _ in range(self.n_val_batch): + x, y, z = next(iter_val) + self.assertTrue(x.shape == (3, 1)) + self.assertTrue(y.shape == (3, 2)) + self.assertTrue(z.shape == (3, 1)) + collected_x_val.append(x.tolist()) + collected_y_val.append(y.tolist()) + collected_z_val.append(z.tolist()) + + x, y, z = next(iter_val) + self.assertTrue(x.shape == (self.val_remainder, 1)) + self.assertTrue(y.shape == (self.val_remainder, 2)) + self.assertTrue(z.shape == (self.val_remainder, 1)) + collected_x_val.append(x.tolist()) + collected_y_val.append(y.tolist()) + collected_z_val.append(z.tolist()) + + flat_x_train = [x for xl in collected_x_train for xs in xl for x in xs] + flat_x_val = [x for xl in collected_x_val for xs in xl for x in xs] + flat_y_train = [y for yl in collected_y_train for ys in yl for y in ys] + flat_y_val = [y for yl in collected_y_val for ys in yl for y in ys] + flat_z_train = [z for zl in collected_z_train for zs in zl for z in zs] + flat_z_val = [z for zl in collected_z_val for zs in zl for z in zs] + + self.assertEqual(results_x_train, flat_x_train) + self.assertEqual(results_x_val, flat_x_val) + self.assertEqual(results_y_train, flat_y_train) + self.assertEqual(results_y_val, flat_y_val) + self.assertEqual(results_z_train, flat_z_train) + self.assertEqual(results_z_val, flat_z_val) + + finally: + self.teardown_file(file_name) + class DataLoaderEagerLoading(unittest.TestCase): file_name1 = "first_half.root"