Skip to content

Commit 1288eff

Browse files
committed
fixed attribute error
1 parent f0da01e commit 1288eff

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,7 @@ def __init__(
664664

665665
# Re-initialize internal Orbax manager with MaxText's Grain handler
666666
# pylint: disable=access-member-before-definition
667+
# pytype: disable=attribute-error
667668
if self._checkpoint_manager is not None:
668669
root_directory = self._checkpoint_manager.directory
669670

@@ -684,6 +685,7 @@ def __init__(
684685
item_handlers=item_handlers,
685686
options=options,
686687
)
688+
# pytype: enable=attribute-error
687689
# pylint: enable=access-member-before-definition
688690

689691
def save(

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -378,13 +378,11 @@ def _log_metrics(self, loss, step=None, additional_metrics=None, **kwargs):
378378
tflops_per_sec = None
379379
if step_time_delta is not None and step_time_delta > 0:
380380
tflops_per_sec = self._tflops_combined / step_time_delta
381-
tflops_metrics.update(
382-
{
383-
"perf/per_device_tflops_per_sec": tflops_per_sec,
384-
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
385-
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
386-
}
387-
)
381+
tflops_metrics.update({
382+
"perf/per_device_tflops_per_sec": tflops_per_sec,
383+
"perf/per_device_tflops_per_sec_student": self._tflops_student / step_time_delta,
384+
"perf/per_device_tflops_per_sec_teacher": self._tflops_teacher / step_time_delta,
385+
})
388386
for name, value in tflops_metrics.items():
389387
self.metrics_logger.log(self.metrics_prefix, name, value, self._mode, step)
390388

0 commit comments

Comments
 (0)