11import time
2+ from math import isfinite
23from typing import Any
34
45import lightning .pytorch as pl
@@ -17,7 +18,15 @@ class TrainingProgressCallback(pl.Callback):
1718 - C{train/epoch_progress_percent}: Percentage of current epoch completed
1819 - C{train/epoch_duration_sec}: Time elapsed so far in current epoch (updated per batch)
1920 - C{train/epoch_completion_sec}: Total duration of completed training epoch in seconds
21+ - C{train/batch_total_sec}: Time spent processing one training batch
22+ - C{val/epoch_progress_percent}: Percentage of current validation epoch completed
23+ - C{val/epoch_duration_sec}: Time elapsed so far in current validation epoch
2024 - C{val/epoch_completion_sec}: Total duration of completed validation epoch in seconds
25+ - C{val/batch_total_sec}: Time spent processing one validation batch
26+ - C{test/epoch_progress_percent}: Percentage of current test epoch completed
27+ - C{test/epoch_duration_sec}: Time elapsed so far in current test epoch
28+ - C{test/epoch_completion_sec}: Total duration of completed test epoch in seconds
29+ - C{test/batch_total_sec}: Time spent processing one test batch
2130 """
2231
2332 def __init__ (self , log_every_n_batches : int = 1 ):
@@ -32,28 +41,78 @@ def __init__(self, log_every_n_batches: int = 1):
3241 self .log_every_n_batches = max (1 , log_every_n_batches )
3342 self ._train_epoch_start_time : float | None = None
3443 self ._val_epoch_start_time : float | None = None
44+ self ._test_epoch_start_time : float | None = None
45+ self ._train_batch_start_time : float | None = None
46+ self ._val_batch_start_time : float | None = None
47+ self ._test_batch_start_time : float | None = None
48+ self ._train_batch_step = 0
49+ self ._val_batch_step = 0
50+ self ._test_batch_step = 0
51+ self ._val_epoch_batch_count = 0
52+ self ._test_epoch_batch_count = 0
53+
54+ @staticmethod
55+ def _now () -> float :
56+ return time .perf_counter ()
57+
58+ @staticmethod
59+ def _elapsed (start_time : float | None ) -> float :
60+ if start_time is None :
61+ return 0.0
62+ return time .perf_counter () - start_time
63+
64+ @staticmethod
65+ def _total_batches (
66+ total_batches : float | list [int | float ],
67+ ) -> int :
68+ """Return the total number of batches across eval
69+ dataloaders.
70+ """
71+ if isinstance (total_batches , list ):
72+ return sum (
73+ int (batch_count )
74+ for batch_count in total_batches
75+ if isfinite (batch_count )
76+ )
77+ if not isfinite (total_batches ):
78+ return 0
79+ return int (total_batches )
3580
3681 @override
3782 def on_train_epoch_start (
3883 self ,
3984 trainer : pl .Trainer ,
4085 pl_module : "lxt.LuxonisLightningModule" ,
4186 ) -> None :
42- self ._train_epoch_start_time = time . time ()
87+ self ._train_epoch_start_time = self . _now ()
4388
4489 if trainer .logger is None :
4590 logger .warning (
4691 "TrainingProgressCallback requires a logger to be configured."
4792 )
4893 return
4994
95+ # Keep train progress/timing metrics on a cumulative batch axis.
96+ # `global_step` tracks optimizer steps, so with gradient
97+ # accumulation multiple train batches can collapse onto the same
98+ # step and stop being truly per-batch aligned.
5099 trainer .logger .log_metrics (
51100 {
52101 "train/epoch_progress_percent" : 0.0 ,
53102 },
54- step = trainer . global_step ,
103+ step = self . _train_batch_step ,
55104 )
56105
106+ @override
107+ def on_train_batch_start (
108+ self ,
109+ trainer : pl .Trainer ,
110+ pl_module : "lxt.LuxonisLightningModule" ,
111+ batch : Any ,
112+ batch_idx : int ,
113+ ) -> None :
114+ self ._train_batch_start_time = self ._now ()
115+
57116 @rank_zero_only
58117 @override
59118 def on_train_batch_end (
@@ -64,11 +123,13 @@ def on_train_batch_end(
64123 batch : Any ,
65124 batch_idx : int ,
66125 ) -> None :
126+ self ._train_batch_step += 1
127+
67128 if trainer .logger is None :
68129 return
69130
70131 # Log every N batches to reduce overhead
71- if (batch_idx + 1 ) % self . log_every_n_batches != 0 :
132+ if not self . _should_log_batch (batch_idx + 1 ):
72133 return
73134
74135 total_batches = trainer .num_training_batches
@@ -79,18 +140,16 @@ def on_train_batch_end(
79140 else 0.0
80141 )
81142
82- epoch_duration = (
83- time .time () - self ._train_epoch_start_time
84- if self ._train_epoch_start_time is not None
85- else 0.0
86- )
143+ epoch_duration = self ._elapsed (self ._train_epoch_start_time )
144+ batch_total = self ._elapsed (self ._train_batch_start_time )
87145
88146 trainer .logger .log_metrics (
89147 {
90148 "train/epoch_progress_percent" : progress_percent ,
91149 "train/epoch_duration_sec" : epoch_duration ,
150+ "train/batch_total_sec" : batch_total ,
92151 },
93- step = trainer . global_step ,
152+ step = self . _train_batch_step ,
94153 )
95154
96155 @rank_zero_only
@@ -103,18 +162,14 @@ def on_train_epoch_end(
103162 if trainer .logger is None :
104163 return
105164
106- epoch_duration = (
107- time .time () - self ._train_epoch_start_time
108- if self ._train_epoch_start_time is not None
109- else 0.0
110- )
165+ epoch_duration = self ._elapsed (self ._train_epoch_start_time )
111166
112167 trainer .logger .log_metrics (
113168 {
114169 "train/epoch_completion_sec" : epoch_duration ,
115170 "train/epoch_progress_percent" : 100.0 ,
116171 },
117- step = trainer . current_epoch ,
172+ step = self . _train_batch_step ,
118173 )
119174
120175 @override
@@ -123,7 +178,72 @@ def on_validation_epoch_start(
123178 trainer : pl .Trainer ,
124179 pl_module : "lxt.LuxonisLightningModule" ,
125180 ) -> None :
126- self ._val_epoch_start_time = time .time ()
181+ self ._val_epoch_start_time = self ._now ()
182+ self ._val_epoch_batch_count = 0
183+
184+ if trainer .sanity_checking or trainer .logger is None :
185+ return
186+
187+ trainer .logger .log_metrics (
188+ {"val/epoch_progress_percent" : 0.0 },
189+ step = self ._val_batch_step ,
190+ )
191+
192+ @override
193+ def on_validation_batch_start (
194+ self ,
195+ trainer : pl .Trainer ,
196+ pl_module : "lxt.LuxonisLightningModule" ,
197+ batch : Any ,
198+ batch_idx : int ,
199+ dataloader_idx : int = 0 ,
200+ ) -> None :
201+ if trainer .sanity_checking :
202+ return
203+
204+ self ._val_batch_start_time = self ._now ()
205+
206+ @rank_zero_only
207+ @override
208+ def on_validation_batch_end (
209+ self ,
210+ trainer : pl .Trainer ,
211+ pl_module : "lxt.LuxonisLightningModule" ,
212+ outputs : STEP_OUTPUT ,
213+ batch : Any ,
214+ batch_idx : int ,
215+ dataloader_idx : int = 0 ,
216+ ) -> None :
217+ if trainer .sanity_checking :
218+ return
219+
220+ self ._val_epoch_batch_count += 1
221+ self ._val_batch_step += 1
222+
223+ if trainer .logger is None :
224+ return
225+
226+ if not self ._should_log_batch (self ._val_epoch_batch_count ):
227+ return
228+
229+ total_batches = self ._total_batches (trainer .num_val_batches )
230+ progress_percent = (
231+ (self ._val_epoch_batch_count / total_batches ) * 100
232+ if total_batches > 0
233+ else 0.0
234+ )
235+ epoch_duration = self ._elapsed (self ._val_epoch_start_time )
236+
237+ trainer .logger .log_metrics (
238+ {
239+ "val/batch_total_sec" : self ._elapsed (
240+ self ._val_batch_start_time
241+ ),
242+ "val/epoch_progress_percent" : progress_percent ,
243+ "val/epoch_duration_sec" : epoch_duration ,
244+ },
245+ step = self ._val_batch_step ,
246+ )
127247
128248 @rank_zero_only
129249 @override
@@ -135,13 +255,116 @@ def on_validation_epoch_end(
135255 if trainer .sanity_checking or trainer .logger is None :
136256 return
137257
138- epoch_duration = (
139- time .time () - self ._val_epoch_start_time
140- if self ._val_epoch_start_time is not None
258+ epoch_duration = self ._elapsed (self ._val_epoch_start_time )
259+
260+ if self ._val_epoch_batch_count > 0 and not self ._should_log_batch (
261+ self ._val_epoch_batch_count
262+ ):
263+ trainer .logger .log_metrics (
264+ {
265+ "val/epoch_progress_percent" : 100.0 ,
266+ "val/epoch_duration_sec" : epoch_duration ,
267+ },
268+ step = self ._val_batch_step ,
269+ )
270+ trainer .logger .log_metrics (
271+ {"val/epoch_completion_sec" : epoch_duration },
272+ step = trainer .current_epoch ,
273+ )
274+
275+ @override
276+ def on_test_epoch_start (
277+ self ,
278+ trainer : pl .Trainer ,
279+ pl_module : "lxt.LuxonisLightningModule" ,
280+ ) -> None :
281+ self ._test_epoch_start_time = self ._now ()
282+ self ._test_epoch_batch_count = 0
283+
284+ if trainer .logger is None :
285+ return
286+
287+ trainer .logger .log_metrics (
288+ {"test/epoch_progress_percent" : 0.0 },
289+ step = self ._test_batch_step ,
290+ )
291+
292+ @override
293+ def on_test_batch_start (
294+ self ,
295+ trainer : pl .Trainer ,
296+ pl_module : "lxt.LuxonisLightningModule" ,
297+ batch : Any ,
298+ batch_idx : int ,
299+ dataloader_idx : int = 0 ,
300+ ) -> None :
301+ self ._test_batch_start_time = self ._now ()
302+
303+ @rank_zero_only
304+ @override
305+ def on_test_batch_end (
306+ self ,
307+ trainer : pl .Trainer ,
308+ pl_module : "lxt.LuxonisLightningModule" ,
309+ outputs : STEP_OUTPUT ,
310+ batch : Any ,
311+ batch_idx : int ,
312+ dataloader_idx : int = 0 ,
313+ ) -> None :
314+ self ._test_epoch_batch_count += 1
315+ self ._test_batch_step += 1
316+
317+ if trainer .logger is None :
318+ return
319+
320+ if not self ._should_log_batch (self ._test_epoch_batch_count ):
321+ return
322+
323+ total_batches = self ._total_batches (trainer .num_test_batches )
324+ progress_percent = (
325+ (self ._test_epoch_batch_count / total_batches ) * 100
326+ if total_batches > 0
141327 else 0.0
142328 )
329+ epoch_duration = self ._elapsed (self ._test_epoch_start_time )
143330
144331 trainer .logger .log_metrics (
145- {"val/epoch_completion_sec" : epoch_duration },
332+ {
333+ "test/batch_total_sec" : self ._elapsed (
334+ self ._test_batch_start_time
335+ ),
336+ "test/epoch_progress_percent" : progress_percent ,
337+ "test/epoch_duration_sec" : epoch_duration ,
338+ },
339+ step = self ._test_batch_step ,
340+ )
341+
342+ @rank_zero_only
343+ @override
344+ def on_test_epoch_end (
345+ self ,
346+ trainer : pl .Trainer ,
347+ pl_module : "lxt.LuxonisLightningModule" ,
348+ ) -> None :
349+ if trainer .logger is None :
350+ return
351+
352+ epoch_duration = self ._elapsed (self ._test_epoch_start_time )
353+
354+ if self ._test_epoch_batch_count > 0 and not self ._should_log_batch (
355+ self ._test_epoch_batch_count
356+ ):
357+ trainer .logger .log_metrics (
358+ {
359+ "test/epoch_progress_percent" : 100.0 ,
360+ "test/epoch_duration_sec" : epoch_duration ,
361+ },
362+ step = self ._test_batch_step ,
363+ )
364+ trainer .logger .log_metrics (
365+ {"test/epoch_completion_sec" : epoch_duration },
146366 step = trainer .current_epoch ,
147367 )
368+
369+ def _should_log_batch (self , seen_batches : int ) -> bool :
370+ return seen_batches % self .log_every_n_batches == 0
0 commit comments