Skip to content

Commit df0a92f

Browse files
committed
updated formatting
1 parent 1288eff commit df0a92f

1 file changed

Lines changed: 7 additions & 5 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,11 +378,13 @@ def _log_metrics(self, loss, step=None, additional_metrics=None, **kwargs):
378378
tflops_per_sec = None
379379
if step_time_delta is not None and step_time_delta > 0:
380380
tflops_per_sec = self._tflops_combined / step_time_delta
381-
tflops_metrics.update({
382-
"perf/per_device_tflops_per_sec": tflops_per_sec,
383-
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
384-
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
385-
})
381+
tflops_metrics.update(
382+
{
383+
"perf/per_device_tflops_per_sec": tflops_per_sec,
384+
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
385+
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
386+
}
387+
)
386388
for name, value in tflops_metrics.items():
387389
self.metrics_logger.log(self.metrics_prefix, name, value, self._mode, step)
388390

0 commit comments

Comments
 (0)