Skip to content

Commit e68c353

Browse files
authored
Added per-batch timings, added eval and test timings (#390)
1 parent 8505b0a commit e68c353

4 files changed

Lines changed: 323 additions & 28 deletions

File tree

luxonis_train/callbacks/README.md

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,17 @@ Callback that publishes training progress and timing metrics.
198198

199199
**Published Metrics:**
200200

201-
| Metric Key | Description |
202-
| ------------------------------ | ------------------------------------------------------- |
203-
| `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed |
204-
| `train/epoch_duration_sec` | Time elapsed so far in current epoch |
205-
| `train/epoch_completion_sec` | Total duration of completed training epoch in seconds |
206-
| `val/epoch_completion_sec` | Total duration of completed validation epoch in seconds |
201+
| Metric Key | Description |
202+
| ------------------------------ | -------------------------------------------------------- |
203+
| `train/batch_total_sec` | Time spent processing one training batch |
204+
| `train/epoch_progress_percent` | Percentage (0-100) of current epoch completed |
205+
| `train/epoch_duration_sec` | Time elapsed so far in current epoch |
206+
| `train/epoch_completion_sec` | Total duration of completed training epoch in seconds |
207+
| `val/batch_total_sec` | Time spent processing one validation batch |
208+
| `val/epoch_progress_percent` | Percentage (0-100) of current validation epoch completed |
209+
| `val/epoch_duration_sec` | Time elapsed so far in current validation epoch |
210+
| `val/epoch_completion_sec` | Total duration of completed validation epoch in seconds |
211+
| `test/batch_total_sec` | Time spent processing one test batch |
212+
| `test/epoch_progress_percent` | Percentage (0-100) of current test epoch completed |
213+
| `test/epoch_duration_sec` | Time elapsed so far in current test epoch |
214+
| `test/epoch_completion_sec` | Total duration of completed test epoch in seconds |

luxonis_train/callbacks/training_progress_callback.py

Lines changed: 243 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time
2+
from math import isfinite
23
from typing import Any
34

45
import lightning.pytorch as pl
@@ -17,7 +18,15 @@ class TrainingProgressCallback(pl.Callback):
1718
- C{train/epoch_progress_percent}: Percentage of current epoch completed
1819
- C{train/epoch_duration_sec}: Time elapsed so far in current epoch (updated per batch)
1920
- C{train/epoch_completion_sec}: Total duration of completed training epoch in seconds
21+
- C{train/batch_total_sec}: Time spent processing one training batch
22+
- C{val/epoch_progress_percent}: Percentage of current validation epoch completed
23+
- C{val/epoch_duration_sec}: Time elapsed so far in current validation epoch
2024
- C{val/epoch_completion_sec}: Total duration of completed validation epoch in seconds
25+
- C{val/batch_total_sec}: Time spent processing one validation batch
26+
- C{test/epoch_progress_percent}: Percentage of current test epoch completed
27+
- C{test/epoch_duration_sec}: Time elapsed so far in current test epoch
28+
- C{test/epoch_completion_sec}: Total duration of completed test epoch in seconds
29+
- C{test/batch_total_sec}: Time spent processing one test batch
2130
"""
2231

2332
def __init__(self, log_every_n_batches: int = 1):
@@ -32,28 +41,78 @@ def __init__(self, log_every_n_batches: int = 1):
3241
self.log_every_n_batches = max(1, log_every_n_batches)
3342
self._train_epoch_start_time: float | None = None
3443
self._val_epoch_start_time: float | None = None
44+
self._test_epoch_start_time: float | None = None
45+
self._train_batch_start_time: float | None = None
46+
self._val_batch_start_time: float | None = None
47+
self._test_batch_start_time: float | None = None
48+
self._train_batch_step = 0
49+
self._val_batch_step = 0
50+
self._test_batch_step = 0
51+
self._val_epoch_batch_count = 0
52+
self._test_epoch_batch_count = 0
53+
54+
@staticmethod
55+
def _now() -> float:
56+
return time.perf_counter()
57+
58+
@staticmethod
59+
def _elapsed(start_time: float | None) -> float:
60+
if start_time is None:
61+
return 0.0
62+
return time.perf_counter() - start_time
63+
64+
@staticmethod
65+
def _total_batches(
66+
total_batches: float | list[int | float],
67+
) -> int:
68+
"""Return the total number of batches across eval
69+
dataloaders.
70+
"""
71+
if isinstance(total_batches, list):
72+
return sum(
73+
int(batch_count)
74+
for batch_count in total_batches
75+
if isfinite(batch_count)
76+
)
77+
if not isfinite(total_batches):
78+
return 0
79+
return int(total_batches)
3580

3681
@override
def on_train_epoch_start(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
) -> None:
    """Start the train-epoch timer and publish 0% epoch progress.

    When the trainer has no logger, a warning is emitted and nothing
    is published.
    """
    self._train_epoch_start_time = self._now()

    active_logger = trainer.logger
    if active_logger is None:
        logger.warning(
            "TrainingProgressCallback requires a logger to be configured."
        )
        return

    # Train progress/timing metrics use a cumulative batch counter as
    # their step. `global_step` counts optimizer steps, so under
    # gradient accumulation several train batches can share one step
    # and would no longer be aligned per batch.
    initial_progress = {"train/epoch_progress_percent": 0.0}
    active_logger.log_metrics(
        initial_progress,
        step=self._train_batch_step,
    )
56105

106+
@override
def on_train_batch_start(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
    batch: Any,
    batch_idx: int,
) -> None:
    """Record the moment the current training batch starts."""
    batch_started_at = self._now()
    self._train_batch_start_time = batch_started_at
115+
57116
@rank_zero_only
58117
@override
59118
def on_train_batch_end(
@@ -64,11 +123,13 @@ def on_train_batch_end(
64123
batch: Any,
65124
batch_idx: int,
66125
) -> None:
126+
self._train_batch_step += 1
127+
67128
if trainer.logger is None:
68129
return
69130

70131
# Log every N batches to reduce overhead
71-
if (batch_idx + 1) % self.log_every_n_batches != 0:
132+
if not self._should_log_batch(batch_idx + 1):
72133
return
73134

74135
total_batches = trainer.num_training_batches
@@ -79,18 +140,16 @@ def on_train_batch_end(
79140
else 0.0
80141
)
81142

82-
epoch_duration = (
83-
time.time() - self._train_epoch_start_time
84-
if self._train_epoch_start_time is not None
85-
else 0.0
86-
)
143+
epoch_duration = self._elapsed(self._train_epoch_start_time)
144+
batch_total = self._elapsed(self._train_batch_start_time)
87145

88146
trainer.logger.log_metrics(
89147
{
90148
"train/epoch_progress_percent": progress_percent,
91149
"train/epoch_duration_sec": epoch_duration,
150+
"train/batch_total_sec": batch_total,
92151
},
93-
step=trainer.global_step,
152+
step=self._train_batch_step,
94153
)
95154

96155
@rank_zero_only
@@ -103,18 +162,14 @@ def on_train_epoch_end(
103162
if trainer.logger is None:
104163
return
105164

106-
epoch_duration = (
107-
time.time() - self._train_epoch_start_time
108-
if self._train_epoch_start_time is not None
109-
else 0.0
110-
)
165+
epoch_duration = self._elapsed(self._train_epoch_start_time)
111166

112167
trainer.logger.log_metrics(
113168
{
114169
"train/epoch_completion_sec": epoch_duration,
115170
"train/epoch_progress_percent": 100.0,
116171
},
117-
step=trainer.current_epoch,
172+
step=self._train_batch_step,
118173
)
119174

120175
@override
@@ -123,7 +178,72 @@ def on_validation_epoch_start(
123178
trainer: pl.Trainer,
124179
pl_module: "lxt.LuxonisLightningModule",
125180
) -> None:
126-
self._val_epoch_start_time = time.time()
181+
self._val_epoch_start_time = self._now()
182+
self._val_epoch_batch_count = 0
183+
184+
if trainer.sanity_checking or trainer.logger is None:
185+
return
186+
187+
trainer.logger.log_metrics(
188+
{"val/epoch_progress_percent": 0.0},
189+
step=self._val_batch_step,
190+
)
191+
192+
@override
def on_validation_batch_start(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Record the moment the current validation batch starts.

    Nothing is recorded while the trainer is sanity checking.
    """
    if not trainer.sanity_checking:
        self._val_batch_start_time = self._now()
205+
206+
@rank_zero_only
@override
def on_validation_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
    outputs: STEP_OUTPUT,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Count the finished validation batch and publish its metrics.

    Sanity-check batches are ignored entirely. Otherwise the per-epoch
    and cumulative validation batch counters are advanced, and metrics
    are published only when a logger exists and the per-epoch count
    falls on a C{log_every_n_batches} boundary.
    """
    if trainer.sanity_checking:
        return

    self._val_epoch_batch_count += 1
    self._val_batch_step += 1

    active_logger = trainer.logger
    if active_logger is None or not self._should_log_batch(
        self._val_epoch_batch_count
    ):
        return

    # Progress is measured against the batches of all validation
    # dataloaders combined; with no known total it is reported as 0.
    expected_batches = self._total_batches(trainer.num_val_batches)
    progress_percent = 0.0
    if expected_batches > 0:
        progress_percent = (
            self._val_epoch_batch_count / expected_batches
        ) * 100

    epoch_duration = self._elapsed(self._val_epoch_start_time)
    batch_duration = self._elapsed(self._val_batch_start_time)

    active_logger.log_metrics(
        {
            "val/batch_total_sec": batch_duration,
            "val/epoch_progress_percent": progress_percent,
            "val/epoch_duration_sec": epoch_duration,
        },
        step=self._val_batch_step,
    )
127247

128248
@rank_zero_only
129249
@override
@@ -135,13 +255,116 @@ def on_validation_epoch_end(
135255
if trainer.sanity_checking or trainer.logger is None:
136256
return
137257

138-
epoch_duration = (
139-
time.time() - self._val_epoch_start_time
140-
if self._val_epoch_start_time is not None
258+
epoch_duration = self._elapsed(self._val_epoch_start_time)
259+
260+
if self._val_epoch_batch_count > 0 and not self._should_log_batch(
261+
self._val_epoch_batch_count
262+
):
263+
trainer.logger.log_metrics(
264+
{
265+
"val/epoch_progress_percent": 100.0,
266+
"val/epoch_duration_sec": epoch_duration,
267+
},
268+
step=self._val_batch_step,
269+
)
270+
trainer.logger.log_metrics(
271+
{"val/epoch_completion_sec": epoch_duration},
272+
step=trainer.current_epoch,
273+
)
274+
275+
@override
def on_test_epoch_start(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
) -> None:
    """Start the test-epoch timer, reset the per-epoch batch count
    and publish 0% test progress when a logger is available.
    """
    self._test_epoch_batch_count = 0
    self._test_epoch_start_time = self._now()

    active_logger = trainer.logger
    if active_logger is not None:
        active_logger.log_metrics(
            {"test/epoch_progress_percent": 0.0},
            step=self._test_batch_step,
        )
291+
292+
@override
def on_test_batch_start(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Record the moment the current test batch starts."""
    batch_started_at = self._now()
    self._test_batch_start_time = batch_started_at
302+
303+
@rank_zero_only
@override
def on_test_batch_end(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
    outputs: STEP_OUTPUT,
    batch: Any,
    batch_idx: int,
    dataloader_idx: int = 0,
) -> None:
    """Count the finished test batch and publish its metrics.

    The per-epoch and cumulative test batch counters are always
    advanced. Metrics are published only when a logger exists and the
    per-epoch count falls on a C{log_every_n_batches} boundary.
    """
    self._test_epoch_batch_count += 1
    self._test_batch_step += 1

    active_logger = trainer.logger
    if active_logger is None or not self._should_log_batch(
        self._test_epoch_batch_count
    ):
        return

    # Progress is measured against the batches of all test dataloaders
    # combined; with no known total it is reported as 0.
    expected_batches = self._total_batches(trainer.num_test_batches)
    progress_percent = 0.0
    if expected_batches > 0:
        progress_percent = (
            self._test_epoch_batch_count / expected_batches
        ) * 100

    epoch_duration = self._elapsed(self._test_epoch_start_time)
    batch_duration = self._elapsed(self._test_batch_start_time)

    active_logger.log_metrics(
        {
            "test/batch_total_sec": batch_duration,
            "test/epoch_progress_percent": progress_percent,
            "test/epoch_duration_sec": epoch_duration,
        },
        step=self._test_batch_step,
    )
341+
342+
@rank_zero_only
@override
def on_test_epoch_end(
    self,
    trainer: pl.Trainer,
    pl_module: "lxt.LuxonisLightningModule",
) -> None:
    """Publish the final timing metrics of the test epoch.

    Does nothing without a logger. Otherwise always publishes
    C{test/epoch_completion_sec} at step C{trainer.current_epoch}.
    """
    active_logger = trainer.logger
    if active_logger is None:
        return

    epoch_duration = self._elapsed(self._test_epoch_start_time)

    # If the last counted batch did not fall on a logging boundary,
    # its batch-end hook published nothing, so close the progress
    # series at 100% here.
    seen_batches = self._test_epoch_batch_count
    last_batch_unlogged = seen_batches > 0 and not self._should_log_batch(
        seen_batches
    )
    if last_batch_unlogged:
        final_progress = {
            "test/epoch_progress_percent": 100.0,
            "test/epoch_duration_sec": epoch_duration,
        }
        active_logger.log_metrics(
            final_progress,
            step=self._test_batch_step,
        )

    active_logger.log_metrics(
        {"test/epoch_completion_sec": epoch_duration},
        step=trainer.current_epoch,
    )
368+
369+
def _should_log_batch(self, seen_batches: int) -> bool:
    """Tell whether C{seen_batches} falls on a logging boundary.

    @type seen_batches: int
    @param seen_batches: Number of batches counted so far.
    @rtype: bool
    @return: C{True} when C{seen_batches} is a multiple of
        C{self.log_every_n_batches}.
    """
    remainder = seen_batches % self.log_every_n_batches
    return remainder == 0

luxonis_train/lightning/luxonis_lightning.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,10 +1025,18 @@ def get_mlflow_logging_keys(self) -> dict[str, list[str]]:
10251025
elif callback.name == "TrainingProgressCallback":
10261026
metric_keys.update(
10271027
{
1028+
"train/batch_total_sec",
10281029
"train/epoch_progress_percent",
10291030
"train/epoch_duration_sec",
10301031
"train/epoch_completion_sec",
1032+
"val/batch_total_sec",
1033+
"val/epoch_progress_percent",
1034+
"val/epoch_duration_sec",
10311035
"val/epoch_completion_sec",
1036+
"test/batch_total_sec",
1037+
"test/epoch_progress_percent",
1038+
"test/epoch_duration_sec",
1039+
"test/epoch_completion_sec",
10321040
}
10331041
)
10341042

0 commit comments

Comments
 (0)