Skip to content

Commit c4e18aa

Browse files
committed
[doc][Python] Add doxygen docstrings to the ML section + refactoring
1 parent cf59381 commit c4e18aa

3 files changed

Lines changed: 43 additions & 45 deletions

File tree

bindings/pyroot/pythonizations/doc/index.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ h.Fill(data)
100100
101101
# Write it to a ROOT file
102102
with ROOT.TFile.Open("output.root", "RECREATE") as f:
103-
h.Write()
103+
f.WriteObject(h, "my_histogram")
104104
~~~
105105

106106
Now we create an RDataFrame from scratch, define a new column from a C++ expression string and draw a histogram:
@@ -111,8 +111,8 @@ import numpy as np
111111
# Create an RDataFrame with 10000 rows
112112
rdf = ROOT.RDataFrame(10000)
113113
114-
# Define a column x
115-
rdf = rdf.Define("x", lambda : np.random.normal(0, 1))
114+
# Define a column x filled with values drawn from a standard normal distribution (mean 0, sigma 1)
115+
rdf = rdf.Define("x", "gRandom->Gaus(0, 1)")
116116
117117
# Draw a histogram of x
118118
rdf.Histo1D("x").Draw()

bindings/pyroot/pythonizations/python/ROOT/_pythonization/_ml_dataloader.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,7 @@ def __exit__(self, type, value, traceback):
519519
# formatted iterator (returned by as_torch / as_numpy / as_tensorflow)
520520
class FormattedLoader:
521521
"""
522+
\ingroup Py_ML
522523
Iterable that converts each batch to the requested format.
523524
Returned by the as_torch / as_numpy / as_tensorflow methods on RDataLoader.
524525
"""
@@ -550,6 +551,7 @@ def __iter__(self):
550551

551552
class RDataLoader:
552553
"""
554+
\ingroup Py_ML
553555
Entry point for ML batch loading from a ROOT RDataFrame.
554556
555557
Usage without a validation split::
@@ -587,6 +589,8 @@ def __init__(
587589
replacement: bool = False,
588590
) -> None:
589591
"""
592+
\ingroup Py_ML
593+
590594
Args:
591595
rdataframes:
592596
RDataFrame or list of RDataFrames to load from.
@@ -699,13 +703,15 @@ def train_test_split(self, test_size: float = 0.2) -> Tuple[RDataLoader, RDataLo
699703

700704
def as_numpy(self) -> FormattedLoader:
701705
"""
706+
\ingroup Py_ML
702707
Return an iterable that yields batches as NumPy arrays.
703708
"""
704709
self._ensure_created()
705710
return FormattedLoader(self._internal, self._internal.ConvertBatchToNumpy, self._is_training)
706711

707712
def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
708713
"""
714+
\ingroup Py_ML
709715
Return an iterable that yields batches as PyTorch tensors.
710716
711717
Args:
@@ -717,6 +723,7 @@ def as_torch(self, device: str | torch.device | None = None) -> FormattedLoader:
717723

718724
def as_tensorflow(self) -> tf.data.Dataset:
719725
"""
726+
\ingroup Py_ML
720727
Return a tf.data.Dataset over batches as TensorFlow tensors.
721728
"""
722729
import tensorflow as tf
@@ -751,14 +758,20 @@ def as_tensorflow(self) -> tf.data.Dataset:
751758

752759
@property
753760
def columns(self) -> list[str]:
754-
"""All column names as they appear in each batch tensor."""
761+
"""
762+
\ingroup Py_ML
763+
All column names as they appear in each batch tensor.
764+
"""
755765
if self._internal is None:
756766
return self._params["columns"]
757767
return self._internal.all_columns
758768

759769
@property
760770
def train_columns(self) -> list[str]:
761-
"""Feature column names (columns minus target and weights)."""
771+
"""
772+
\ingroup Py_ML
773+
Feature column names (columns minus target and weights).
774+
"""
762775
if self._internal is None:
763776
target = self._params["target"] if self._params["target"] is not None else []
764777
weights = self._params["weights"] if self._params["weights"] is not None else []
@@ -767,21 +780,29 @@ def train_columns(self) -> list[str]:
767780

768781
@property
769782
def target_columns(self) -> list[str]:
770-
"""Target column names."""
783+
"""\ingroup Py_ML
784+
Target column names.
785+
"""
771786
if self._internal is None:
772787
return self._params["target"] if self._params["target"] is not None else []
773788
return self._internal.target_columns
774789

775790
@property
776791
def weights_column(self) -> str:
777-
"""Weights column name, or empty string if not set."""
792+
"""
793+
\ingroup Py_ML
794+
Weights column name, or empty string if not set.
795+
"""
778796
if self._internal is None:
779797
return self._params["weights"] if self._params["weights"] is not None else ""
780798
return self._internal.weights_column
781799

782800
@property
783801
def num_batches(self) -> int:
784-
"""Total number of batches in this split for one epoch."""
802+
"""
803+
\ingroup Py_ML
804+
Total number of batches in this split for one epoch.
805+
"""
785806
if self._internal is None:
786807
raise RuntimeError(
787808
"num_batches is available after the first call to "
@@ -793,7 +814,10 @@ def num_batches(self) -> int:
793814

794815
@property
795816
def last_batch_no_of_rows(self) -> int:
796-
"""Number of rows in the last (remainder) batch, 0 if no remainder."""
817+
"""
818+
\ingroup Py_ML
819+
Number of rows in the last (remainder) batch, 0 if no remainder.
820+
"""
797821
if self._internal is None:
798822
raise RuntimeError(
799823
"last_batch_no_of_rows is available after the first call to "

bindings/pyroot/pythonizations/python/ROOT/_pythonization/dataloader.md

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,9 @@ import ROOT
3838
# Open a ROOT file and create an RDataFrame
3939
rdf = ROOT.RDataFrame("events", "file.root")
4040
41-
# Define a Python callback to compute a new variable
42-
def invariant_mass(E: float, p: float) -> float:
43-
return math.sqrt(E**2 - p**2)
44-
4541
# Apply selections and compute derived features
4642
rdf = rdf.Filter("nMuons >= 2") \
47-
.Define("inv_mass", invariant_mass, ["E", "p"])
43+
.Define("inv_mass", "sqrt(E*E - p*p)")
4844
~~~
4945

5046
Then pass your `RDataFrame` to `RDataLoader`:
@@ -138,7 +134,7 @@ dl = RDataLoader(
138134
# events with fewer than 10 jets are zero-padded
139135
~~~
140136

141-
\warning Every RVec column in `columns` must appear in `max_vec_sizes`.
137+
\warning Every vector column in `columns` must appear in `max_vec_sizes`.
142138

143139
## Iterating Batches
144140

@@ -212,6 +208,14 @@ train, val = train_val.train_test_split(test_size=0.176)
212208

213209
## Advanced Features
214210

211+
### Eager loading
212+
213+
By default the loader reads data lazily, one chunk of data at a time. For small datasets that fit in memory and will be iterated many times, eager loading pays a one-time cost at construction and then serves batches every epoch from memory:
214+
215+
~~~{.py}
216+
dl = RDataLoader(rdf, batch_size=256, load_eager=True)
217+
~~~
218+
215219
### Resampling
216220

217221
Correct class imbalance by oversampling the minority or undersampling the majority. You can do this by passing two RDataFrames:
@@ -244,33 +248,3 @@ dl = RDataLoader(rdf,
244248
for X, y, w in dl.as_torch():
245249
loss = (loss_fn(model(X), y) * w).mean()
246250
~~~
247-
248-
### Eager loading
249-
250-
By default the loader reads data lazily, one chunk of data at a time. For small datasets that fit in memory and will be iterated many times, eager loading pays a one-time cost at construction and then serves every epoch from memory:
251-
252-
~~~{.py}
253-
dl = RDataLoader(rdf, batch_size=256, load_eager=True)
254-
~~~
255-
256-
## API Reference
257-
258-
### RDataLoader(rdataframes, ...)
259-
260-
| Argument | Type | Default | Description |
261-
|---|---|---|---|
262-
| `rdataframes` | `RDF \| list` | - | One or more RDataFrames to load from |
263-
| `batch_size` | `int` | `64` | Number of events per batch |
264-
| `batches_in_memory` | `int` | `10` | Shuffle buffer size in batches |
265-
| `columns` | `list[str]` | `None` | Branches to load - all if not given |
266-
| `max_vec_sizes` | `dict` | `None` | Max size per RVec column |
267-
| `vec_padding` | `float` | `0.0` | Pad value for short RVec entries |
268-
| `target` | `str \| list` | `None` | Label column(s) - returned as `y` |
269-
| `weights` | `str` | `""` | Event weight column - returned as `w` |
270-
| `shuffle` | `bool` | `True` | Randomise event order |
271-
| `drop_remainder` | `bool` | `True` | Drop last incomplete batch |
272-
| `set_seed` | `int` | `0` | RNG seed - 0 means random |
273-
| `load_eager` | `bool` | `False` | Load full dataset into RAM |
274-
| `sampling_type` | `str` | `""` | `"oversampling"` or `"undersampling"` |
275-
| `sampling_ratio` | `float` | `1.0` | Minority/majority ratio after resampling |
276-
| `replacement` | `bool` | `False` | Undersampling with replacement |

0 commit comments

Comments
 (0)