Skip to content

Commit 6c7ff83

Browse files
committed
[Python][ML] Refactor batch conversion into helpers
1 parent ed9460a commit 6c7ff83

1 file changed

Lines changed: 39 additions & 50 deletions

File tree

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py

Lines changed: 39 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -345,35 +345,29 @@ def GetSample(self):
345345
np.zeros((self.batch_size)).reshape(-1, 1),
346346
)
347347

348-
def ConvertBatchToNumpy(self, batch) -> np.ndarray:
349-
"""Convert a RTensor into a NumPy array
350-
351-
Args:
352-
batch (RTensor): Batch returned from the DataLoader
353-
354-
Returns:
355-
np.ndarray: converted batch
356-
"""
348+
def _get_raw_array(self, batch) -> np.ndarray:
357349
try:
358350
import numpy as np
359351
except ImportError:
360352
raise ImportError("Failed to import numpy needed for the ML dataloader")
361353

362354
data = batch.GetData()
363355
batch_size, num_columns = tuple(batch.GetShape())
364-
365356
data.reshape((batch_size * num_columns,))
366357

367-
return_data = np.asarray(data).reshape(batch_size, num_columns)
358+
return np.asarray(data).reshape(batch_size, num_columns)
368359

360+
def _split_target_and_weights(
361+
self, data: np.ndarray
362+
) -> np.ndarray | Tuple[np.ndarray, np.ndarray] | Tuple[np.ndarray, np.ndarray, np.ndarray]:
369363
# Splice target column from the data if target is given
370364
if self.target_given:
371-
train_data = return_data[:, self.train_indices]
372-
target_data = return_data[:, self.target_indices]
365+
train_data = data[:, self.train_indices]
366+
target_data = data[:, self.target_indices]
373367

374368
# Splice weight column from the data if weight is given
375369
if self.weights_given:
376-
weights_data = return_data[:, self.weights_index]
370+
weights_data = data[:, self.weights_index]
377371

378372
if len(self.target_indices) == 1:
379373
return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1)
@@ -385,7 +379,18 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray:
385379

386380
return train_data, target_data
387381

388-
return return_data
382+
return data
383+
384+
def ConvertBatchToNumpy(self, batch) -> np.ndarray:
385+
"""Convert a RTensor into a NumPy array
386+
387+
Args:
388+
batch (RTensor): Batch returned from the DataLoader
389+
390+
Returns:
391+
np.ndarray: converted batch
392+
"""
393+
return self._split_target_and_weights(self._get_raw_array(batch))
389394

390395
def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
391396
"""Convert a RTensor into a PyTorch tensor
@@ -396,36 +401,14 @@ def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
396401
Returns:
397402
torch.Tensor: converted batch
398403
"""
399-
import numpy as np
400404
import torch
401405

402-
data = batch.GetData()
403-
batch_size, num_columns = tuple(batch.GetShape())
404-
405-
data.reshape((batch_size * num_columns,))
406-
407-
return_data = torch.as_tensor(np.asarray(data), device=device).reshape(batch_size, num_columns)
408-
409-
# Splice target column from the data if target is given
410-
if self.target_given:
411-
train_data = return_data[:, self.train_indices]
412-
target_data = return_data[:, self.target_indices]
413-
414-
# Splice weight column from the data if weight is given
415-
if self.weights_given:
416-
weights_data = return_data[:, self.weights_index]
417-
418-
if len(self.target_indices) == 1:
419-
return train_data, target_data.reshape(-1, 1), weights_data.reshape(-1, 1)
420-
421-
return train_data, target_data, weights_data.reshape(-1, 1)
422-
423-
if len(self.target_indices) == 1:
424-
return train_data, target_data.reshape(-1, 1)
425-
426-
return train_data, target_data
427-
428-
return return_data
406+
split = self._split_target_and_weights(self._get_raw_array(batch))
407+
return (
408+
tuple(torch.as_tensor(arr, device=device) for arr in split)
409+
if isinstance(split, tuple)
410+
else torch.as_tensor(split, device=device)
411+
)
429412

430413
def ConvertBatchToTF(self, batch: Any) -> Any:
431414
"""
@@ -439,12 +422,9 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
439422
"""
440423
import tensorflow as tf
441424

442-
data = batch.GetData()
443-
batch_size, num_columns = tuple(batch.GetShape())
444-
445-
data.reshape((batch_size * num_columns,))
446-
447-
return_data = tf.constant(data, shape=(batch_size, num_columns))
425+
arr = self._get_raw_array(batch)
426+
batch_size = arr.shape[0]
427+
return_data = tf.constant(arr)
448428

449429
if batch_size != self.batch_size:
450430
return_data = tf.pad(return_data, tf.constant([[0, self.batch_size - batch_size], [0, 0]]))
@@ -464,6 +444,7 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
464444

465445
return return_data
466446

447+
467448
# Return a batch when available
468449
def GetTrainBatch(self) -> Any:
469450
"""Return the next training batch of data from the given RDataFrame
@@ -717,6 +698,11 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
717698
Args:
718699
device: If given, the returned tensors are moved to the specified device.
719700
"""
701+
try:
702+
import torch # noqa F401
703+
except ImportError:
704+
raise ImportError("Failed to import torch needed for the ML dataloader")
705+
720706
self._ensure_created()
721707
conversion_fn = lambda batch: self._internal.ConvertBatchToPyTorch(batch, device) # noqa: E731
722708
return FormattedLoader(self._internal, conversion_fn, self._is_training)
@@ -726,7 +712,10 @@ def as_tensorflow(self) -> tf.data.Dataset:
726712
\ingroup Py_ML
727713
Return a tf.data.Dataset over batches as TensorFlow tensors.
728714
"""
729-
import tensorflow as tf
715+
try:
716+
import tensorflow as tf
717+
except ImportError:
718+
raise ImportError("Failed to import tensorflow needed for the ML dataloader")
730719

731720
self._ensure_created()
732721

0 commit comments

Comments
 (0)