Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def __exit__(self, type, value, traceback):

# formatted iterator (returned by as_torch / as_numpy / as_tensorflow)
class FormattedLoader:
"""
r"""
\ingroup Py_ML
Iterable that converts each batch to the requested format.
Returned by the as_torch / as_numpy / as_tensorflow methods on RDataLoader.
Expand Down Expand Up @@ -550,7 +550,7 @@ def __iter__(self):


class RDataLoader:
"""
r"""
\ingroup Py_ML
Entry point for ML batch loading from a ROOT RDataFrame.

Expand Down Expand Up @@ -588,7 +588,7 @@ def __init__(
sampling_ratio: float = 1.0,
replacement: bool = False,
) -> None:
"""
r"""
\ingroup Py_ML

Args:
Expand Down Expand Up @@ -702,15 +702,15 @@ def train_test_split(self, test_size: float = 0.2) -> Tuple[RDataLoader, RDataLo
)

def as_numpy(self) -> FormattedLoader:
"""
r"""
\ingroup Py_ML
Return an iterable that yields batches as NumPy arrays.
"""
self._ensure_created()
return FormattedLoader(self._internal, self._internal.ConvertBatchToNumpy, self._is_training)

def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
"""
r"""
\ingroup Py_ML
Return an iterable that yields batches as PyTorch tensors.

Expand All @@ -722,7 +722,7 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
return FormattedLoader(self._internal, conversion_fn, self._is_training)

def as_tensorflow(self) -> tf.data.Dataset:
"""
r"""
\ingroup Py_ML
Return a tf.data.Dataset over batches as TensorFlow tensors.
"""
Expand Down Expand Up @@ -758,7 +758,7 @@ def as_tensorflow(self) -> tf.data.Dataset:

@property
def columns(self) -> list[str]:
"""
r"""
\ingroup Py_ML
All column names as they appear in each batch tensor.
"""
Expand All @@ -768,7 +768,7 @@ def columns(self) -> list[str]:

@property
def train_columns(self) -> list[str]:
"""
r"""
\ingroup Py_ML
Feature column names (columns minus target and weights).
"""
Expand All @@ -780,7 +780,7 @@ def train_columns(self) -> list[str]:

@property
def target_columns(self) -> list[str]:
"""
r"""
\ingroup Py_ML
Target column names.
"""
Expand All @@ -790,7 +790,7 @@ def target_columns(self) -> list[str]:

@property
def weights_column(self) -> str:
"""
r"""
\ingroup Py_ML
Weights column name, or empty string if not set.
"""
Expand All @@ -800,7 +800,7 @@ def weights_column(self) -> str:

@property
def num_batches(self) -> int:
"""
r"""
\ingroup Py_ML
Total number of batches in this split for one epoch.
"""
Expand All @@ -815,7 +815,7 @@ def num_batches(self) -> int:

@property
def last_batch_no_of_rows(self) -> int:
"""
r"""
\ingroup Py_ML
Number of rows in the last (remainder) batch, 0 if no remainder.
"""
Expand Down
Loading