@@ -345,35 +345,29 @@ def GetSample(self):
345345 np .zeros ((self .batch_size )).reshape (- 1 , 1 ),
346346 )
347347
348- def ConvertBatchToNumpy (self , batch ) -> np .ndarray :
349- """Convert a RTensor into a NumPy array
350-
351- Args:
352- batch (RTensor): Batch returned from the DataLoader
353-
354- Returns:
355- np.ndarray: converted batch
356- """
348+ def _get_raw_array (self , batch ) -> np .ndarray :
357349 try :
358350 import numpy as np
359351 except ImportError :
360352 raise ImportError ("Failed to import numpy needed for the ML dataloader" )
361353
362354 data = batch .GetData ()
363355 batch_size , num_columns = tuple (batch .GetShape ())
364-
365356 data .reshape ((batch_size * num_columns ,))
366357
367- return_data = np .asarray (data ).reshape (batch_size , num_columns )
358+ return np .asarray (data ).reshape (batch_size , num_columns )
368359
360+ def _split_target_and_weights (
361+ self , data : np .ndarray
362+ ) -> np .ndarray | Tuple [np .ndarray , np .ndarray ] | Tuple [np .ndarray , np .ndarray , np .ndarray ]:
369363 # Splice target column from the data if target is given
370364 if self .target_given :
371- train_data = return_data [:, self .train_indices ]
372- target_data = return_data [:, self .target_indices ]
365+ train_data = data [:, self .train_indices ]
366+ target_data = data [:, self .target_indices ]
373367
374368 # Splice weight column from the data if weight is given
375369 if self .weights_given :
376- weights_data = return_data [:, self .weights_index ]
370+ weights_data = data [:, self .weights_index ]
377371
378372 if len (self .target_indices ) == 1 :
379373 return train_data , target_data .reshape (- 1 , 1 ), weights_data .reshape (- 1 , 1 )
@@ -385,7 +379,18 @@ def ConvertBatchToNumpy(self, batch) -> np.ndarray:
385379
386380 return train_data , target_data
387381
388- return return_data
382+ return data
383+
384+ def ConvertBatchToNumpy (self , batch ) -> np .ndarray :
385+ """Convert a RTensor into a NumPy array
386+
387+ Args:
388+ batch (RTensor): Batch returned from the DataLoader
389+
390+ Returns:
391+ np.ndarray: converted batch
392+ """
393+ return self ._split_target_and_weights (self ._get_raw_array (batch ))
389394
390395 def ConvertBatchToPyTorch (self , batch : Any , device = None ) -> torch .Tensor :
391396 """Convert a RTensor into a PyTorch tensor
@@ -396,36 +401,14 @@ def ConvertBatchToPyTorch(self, batch: Any, device=None) -> torch.Tensor:
396401 Returns:
397402 torch.Tensor: converted batch
398403 """
399- import numpy as np
400404 import torch
401405
402- data = batch .GetData ()
403- batch_size , num_columns = tuple (batch .GetShape ())
404-
405- data .reshape ((batch_size * num_columns ,))
406-
407- return_data = torch .as_tensor (np .asarray (data ), device = device ).reshape (batch_size , num_columns )
408-
409- # Splice target column from the data if target is given
410- if self .target_given :
411- train_data = return_data [:, self .train_indices ]
412- target_data = return_data [:, self .target_indices ]
413-
414- # Splice weight column from the data if weight is given
415- if self .weights_given :
416- weights_data = return_data [:, self .weights_index ]
417-
418- if len (self .target_indices ) == 1 :
419- return train_data , target_data .reshape (- 1 , 1 ), weights_data .reshape (- 1 , 1 )
420-
421- return train_data , target_data , weights_data .reshape (- 1 , 1 )
422-
423- if len (self .target_indices ) == 1 :
424- return train_data , target_data .reshape (- 1 , 1 )
425-
426- return train_data , target_data
427-
428- return return_data
406+ split = self ._split_target_and_weights (self ._get_raw_array (batch ))
407+ return (
408+ tuple (torch .as_tensor (arr , device = device ) for arr in split )
409+ if isinstance (split , tuple )
410+ else torch .as_tensor (split , device = device )
411+ )
429412
430413 def ConvertBatchToTF (self , batch : Any ) -> Any :
431414 """
@@ -439,12 +422,9 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
439422 """
440423 import tensorflow as tf
441424
442- data = batch .GetData ()
443- batch_size , num_columns = tuple (batch .GetShape ())
444-
445- data .reshape ((batch_size * num_columns ,))
446-
447- return_data = tf .constant (data , shape = (batch_size , num_columns ))
425+ arr = self ._get_raw_array (batch )
426+ batch_size = arr .shape [0 ]
427+ return_data = tf .constant (arr )
448428
449429 if batch_size != self .batch_size :
450430 return_data = tf .pad (return_data , tf .constant ([[0 , self .batch_size - batch_size ], [0 , 0 ]]))
@@ -464,6 +444,7 @@ def ConvertBatchToTF(self, batch: Any) -> Any:
464444
465445 return return_data
466446
447+
467448 # Return a batch when available
468449 def GetTrainBatch (self ) -> Any :
469450 """Return the next training batch of data from the given RDataFrame
@@ -717,6 +698,11 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
717698 Args:
718699 device: If given, the returned tensors are moved to the specified device.
719700 """
701+ try :
702+ import torch # noqa F401
703+ except ImportError :
704+ raise ImportError ("Failed to import torch needed for the ML dataloader" )
705+
720706 self ._ensure_created ()
721707 conversion_fn = lambda batch : self ._internal .ConvertBatchToPyTorch (batch , device ) # noqa: E731
722708 return FormattedLoader (self ._internal , conversion_fn , self ._is_training )
@@ -726,7 +712,10 @@ def as_tensorflow(self) -> tf.data.Dataset:
726712 \ingroup Py_ML
727713 Return a tf.data.Dataset over batches as TensorFlow tensors.
728714 """
729- import tensorflow as tf
715+ try :
716+ import tensorflow as tf
717+ except ImportError :
718+ raise ImportError ("Failed to import tensorflow needed for the ML dataloader" )
730719
731720 self ._ensure_created ()
732721
0 commit comments