Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -345,35 +345,29 @@ def GetSample(self):
np.zeros((self.batch_size)).reshape(-1, 1),
)

def ConvertBatchToNumpy(self, batch) -> np.ndarray:
"""Convert a RTensor into a NumPy array

Args:
batch (RTensor): Batch returned from the DataLoader

Returns:
np.ndarray: converted batch
"""
def _get_raw_array(self, batch) -> np.ndarray:
try:
import numpy as np
except ImportError:
raise ImportError("Failed to import numpy needed for the ML dataloader")

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())

data.reshape((batch_size * num_columns,))

return_data = np.asarray(data).reshape(batch_size, num_columns)
return np.asarray(data).reshape(batch_size, num_columns)

def _split_target_and_weights(
self, data: np.ndarray
) -> np.ndarray | Tuple[np.ndarray, np.ndarray] | Tuple[np.ndarray, np.ndarray, np.ndarray]:
# Splice target column from the data if target is given
if self.target_given:
train_data = return_data[:, self.train_indices]
target_data = return_data[:, self.target_indices]
train_data = data[:, self.train_indices]
target_data = data[:, self.target_indices]

# Splice weight column from the data if weight is given
if self.weights_given:
weights_data = return_data[:, self.weights_index]
weights_data = data[:, self.weights_index]

if len(self.target_indices) == 1:
return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1)
Expand All @@ -385,66 +379,52 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray:

return train_data, target_data

return return_data
return data

def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
"""Convert a RTensor into a PyTorch tensor
def ConvertBatchToNumpy(self, batch) -> np.ndarray:
"""Convert the batch into a NumPy array

Args:
batch (RTensor): Batch returned from the DataLoader
batch: Batch returned from the DataLoader

Returns:
torch.Tensor: converted batch
np.ndarray: converted batch
"""
import numpy as np
import torch

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())

data.reshape((batch_size * num_columns,))
return self._split_target_and_weights(self._get_raw_array(batch))

return_data = torch.as_tensor(np.asarray(data), device=device).reshape(batch_size, num_columns)

# Splice target column from the data if target is given
if self.target_given:
train_data = return_data[:, self.train_indices]
target_data = return_data[:, self.target_indices]

# Splice weight column from the data if weight is given
if self.weights_given:
weights_data = return_data[:, self.weights_index]

if len(self.target_indices) == 1:
return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1)

return train_data, target_data, weights_data.reshape(-1, 1)
def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
"""Convert the batch into a PyTorch tensor

if len(self.target_indices) == 1:
return train_data, target_data.reshape(-1, 1)
Args:
batch: Batch returned from the DataLoader

return train_data, target_data
Returns:
torch.Tensor: converted batch
"""
import torch

return return_data
split = self._split_target_and_weights(self._get_raw_array(batch))
return (
tuple(torch.as_tensor(arr, device=device) for arr in split)
if isinstance(split, tuple)
else torch.as_tensor(split, device=device)
)

def ConvertBatchToTF(self, batch: Any) -> Any:
"""
Convert a RTensor into a TensorFlow tensor
Convert the batch into a TensorFlow tensor

Args:
batch (RTensor): Batch returned from the DataLoader
batch: Batch returned from the DataLoader

Returns:
tensorflow.Tensor: converted batch
"""
import tensorflow as tf

data = batch.GetData()
batch_size, num_columns = tuple(batch.GetShape())

data.reshape((batch_size * num_columns,))

return_data = tf.constant(data, shape=(batch_size, num_columns))
arr = self._get_raw_array(batch)
batch_size = arr.shape[0]
return_data = tf.constant(arr)

if batch_size != self.batch_size:
return_data = tf.pad(return_data, tf.constant([[0, self.batch_size - batch_size], [0, 0]]))
Expand All @@ -464,6 +444,30 @@ def ConvertBatchToTF(self, batch: Any) -> Any:

return return_data

def ConvertBatchToJAX(self, batch: Any, device=None) -> Any:
"""
Convert the batch into a JAX array

Args:
batch: Batch returned from the DataLoader

Returns:
jax.Array: converted batch
"""
import jax
import jax.numpy as jnp

split = self._split_target_and_weights(jnp.asarray(self._get_raw_array(batch)))

if isinstance(device, str):
device = jax.devices(device)[0]

return (
tuple(jax.device_put(arr, device=device) for arr in split)
if isinstance(split, tuple)
else jax.device_put(split, device=device)
)

# Return a batch when available
def GetTrainBatch(self) -> Any:
"""Return the next training batch of data from the given RDataFrame
Expand Down Expand Up @@ -717,6 +721,11 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
Args:
device: If given, the returned tensors are moved to the specified device.
"""
try:
import torch # noqa F401
except ImportError:
raise ImportError("Failed to import torch needed for the ML dataloader")

self._ensure_created()
conversion_fn = lambda batch: self._internal.ConvertBatchToPyTorch(batch, device) # noqa: E731
return FormattedLoader(self._internal, conversion_fn, self._is_training)
Expand All @@ -726,7 +735,10 @@ def as_tensorflow(self) -> tf.data.Dataset:
\ingroup Py_ML
Return a tf.data.Dataset over batches as TensorFlow tensors.
"""
import tensorflow as tf
try:
import tensorflow as tf
except ImportError:
raise ImportError("Failed to import tensorflow needed for the ML dataloader")

self._ensure_created()

Expand Down Expand Up @@ -756,6 +768,24 @@ def as_tensorflow(self) -> tf.data.Dataset:
loader = FormattedLoader(self._internal, self._internal.ConvertBatchToTF, self._is_training)
return tf.data.Dataset.from_generator(lambda: loader, output_signature=batch_signature)

def as_jax(self, device: str | Any = None) -> FormattedLoader:
r"""
\ingroup Py_ML
Return an iterable that yields batches as JAX arrays.

Args:
device: If given, the returned arrays are moved to the specified device.
Can be a string (e.g. "cpu", "gpu", "tpu") or any of JAX's device objects.
"""
try:
import jax # noqa F401
except ImportError:
raise ImportError("Failed to import jax needed for the ML dataloader")

self._ensure_created()
conversion_fn = lambda batch: self._internal.ConvertBatchToJAX(batch, device) # noqa: E731
return FormattedLoader(self._internal, conversion_fn, self._is_training)

@property
def columns(self) -> list[str]:
r"""
Expand Down
82 changes: 82 additions & 0 deletions bindings/pyroot/pythonizations/test/ml_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,6 +1116,88 @@ def test16_vector_padding(self):
self.teardown_file(self.file_name3)
raise

def test17_JAX(self):
file_name = "multiple_target_columns.root"

ROOT.RDataFrame(10).Define("b1", "(Short_t) rdfentry_").Define("b2", "(UShort_t) b1 * b1").Define(
"b3", "(double) rdfentry_ * 10"
).Define("b4", "(double) b3 * 10").Snapshot("myTree", file_name)

try:
df = ROOT.RDataFrame("myTree", file_name)

dl = ROOT.Experimental.ML.RDataLoader(
df,
batch_size=3,
batches_in_memory=2,
target=["b2", "b4"],
weights="b3",
shuffle=False,
drop_remainder=False,
)

gen_train, gen_validation = dl.train_test_split(0.4)

results_x_train = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
results_x_val = [6.0, 7.0, 8.0, 9.0]
results_y_train = [0.0, 0.0, 1.0, 100.0, 4.0, 200.0, 9.0, 300.0, 16.0, 400.0, 25.0, 500.0]
results_y_val = [36.0, 600.0, 49.0, 700.0, 64.0, 800.0, 81.0, 900.0]
results_z_train = [0.0, 10.0, 20.0, 30.0, 40.0, 50.0]
results_z_val = [60.0, 70.0, 80.0, 90.0]

collected_x_train = []
collected_x_val = []
collected_y_train = []
collected_y_val = []
collected_z_train = []
collected_z_val = []

iter_train = iter(gen_train.as_jax(device="cpu"))
iter_val = iter(gen_validation.as_jax())

for _ in range(self.n_train_batch):
x, y, z = next(iter_train)
self.assertTrue(x.shape == (3, 1))
self.assertTrue(y.shape == (3, 2))
self.assertTrue(z.shape == (3, 1))
collected_x_train.append(x.tolist())
collected_y_train.append(y.tolist())
collected_z_train.append(z.tolist())

for _ in range(self.n_val_batch):
x, y, z = next(iter_val)
self.assertTrue(x.shape == (3, 1))
self.assertTrue(y.shape == (3, 2))
self.assertTrue(z.shape == (3, 1))
collected_x_val.append(x.tolist())
collected_y_val.append(y.tolist())
collected_z_val.append(z.tolist())

x, y, z = next(iter_val)
self.assertTrue(x.shape == (self.val_remainder, 1))
self.assertTrue(y.shape == (self.val_remainder, 2))
self.assertTrue(z.shape == (self.val_remainder, 1))
collected_x_val.append(x.tolist())
collected_y_val.append(y.tolist())
collected_z_val.append(z.tolist())

flat_x_train = [x for xl in collected_x_train for xs in xl for x in xs]
flat_x_val = [x for xl in collected_x_val for xs in xl for x in xs]
flat_y_train = [y for yl in collected_y_train for ys in yl for y in ys]
flat_y_val = [y for yl in collected_y_val for ys in yl for y in ys]
flat_z_train = [z for zl in collected_z_train for zs in zl for z in zs]
flat_z_val = [z for zl in collected_z_val for zs in zl for z in zs]

self.assertEqual(results_x_train, flat_x_train)
self.assertEqual(results_x_val, flat_x_val)
self.assertEqual(results_y_train, flat_y_train)
self.assertEqual(results_y_val, flat_y_val)
self.assertEqual(results_z_train, flat_z_train)
self.assertEqual(results_z_val, flat_z_val)

finally:
self.teardown_file(file_name)


class DataLoaderEagerLoading(unittest.TestCase):
file_name1 = "first_half.root"
Expand Down
Loading