Skip to content

Commit 24f72e0

Browse files
Merge pull request #3798 from AI-Hypercomputer:agagik-flops-fix
PiperOrigin-RevId: 908936231
2 parents: 0d4af87 + 0f6874f · commit 24f72e0

2 files changed

Lines changed: 36 additions & 5 deletions

File tree

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -343,9 +343,14 @@ def _eval_step(self, model, inputs):
343343
labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None))
344344
return self.strategy.compute_eval_loss(student_output, labels)
345345

346-
def _log_metrics(self, loss, step=None, step_time_delta=None, additional_metrics=None):
347-
"""Adds per-device TFLOPs (and per-sec variants) to the standard Tunix metrics."""
348-
super()._log_metrics(loss=loss, step=step, step_time_delta=step_time_delta, additional_metrics=additional_metrics)
346+
def _log_metrics(self, loss, step=None, additional_metrics=None, **kwargs):
347+
"""Adds per-device TFLOPs to the standard Tunix metrics.
348+
349+
`step_time_delta` is consumed via **kwargs so this override works against
350+
older tunix versions whose base `_log_metrics` does not accept it.
351+
"""
352+
super()._log_metrics(loss=loss, step=step, additional_metrics=additional_metrics, **kwargs)
353+
step_time_delta = kwargs.get("step_time_delta")
349354

350355
tflops_metrics = {
351356
"perf/per_device_tflops": self._tflops_combined,

tests/post_training/unit/train_distill_test.py

Lines changed: 28 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -891,7 +891,8 @@ def side_effect(self, *args, **kwargs):
891891
self.assertTrue(any(c == "1" or c.endswith("1") for c in checkpoints), f"Checkpoint 1 not found in {checkpoints}")
892892
self.assertTrue(any(c == "2" or c.endswith("2") for c in checkpoints), f"Checkpoint 2 not found in {checkpoints}")
893893

894-
def test_checkpointing_and_resume(self):
894+
@mock.patch.object(distillation_utils, "calculate_distillation_tflops_per_device", return_value=(0.0, 0.0, 0.0))
895+
def test_checkpointing_and_resume(self, _mock_tflops):
895896
"""Trains a few steps, saves a checkpoint, and resumes from it."""
896897

897898
# 1. Setup minimal dummy model and models bundle
@@ -941,6 +942,8 @@ def __call__(self, input_tokens, **kwargs):
941942
strategy=strategy,
942943
optimizer=optimizer1,
943944
training_config=train_config,
945+
student_config=mock.Mock(),
946+
teacher_config=mock.Mock(),
944947
)
945948
trainer1._lora_enabled = False
946949
trainer1.is_managed_externally = True
@@ -989,6 +992,8 @@ def __call__(self, input_tokens, **kwargs):
989992
strategy=strategy,
990993
optimizer=optimizer2,
991994
training_config=train_config,
995+
student_config=mock.Mock(),
996+
teacher_config=mock.Mock(),
992997
)
993998
trainer2._lora_enabled = False
994999

@@ -1083,8 +1088,17 @@ def test_main_offline_mode_skips_teacher_loading(
10831088
mock_student_cfg.save_checkpoint_on_completion = False
10841089
mock_student_cfg.logical_axis_rules = []
10851090

1091+
# main() validates that student/teacher share batch shape — set explicit
1092+
# equal scalars on both mocks so the assertion passes.
1093+
mock_student_cfg.per_device_batch_size = 1
1094+
mock_student_cfg.max_target_length = 16
1095+
mock_student_cfg.gradient_accumulation_steps = 1
1096+
10861097
mock_teacher_cfg = mock.Mock()
10871098
mock_teacher_cfg.vocab_size = 32000
1099+
mock_teacher_cfg.per_device_batch_size = 1
1100+
mock_teacher_cfg.max_target_length = 16
1101+
mock_teacher_cfg.gradient_accumulation_steps = 1
10881102
mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]
10891103

10901104
# 2. Model Loading
@@ -1181,8 +1195,17 @@ def test_main_online_mode_loads_teacher(
11811195
mock_student_cfg.save_checkpoint_on_completion = False
11821196
mock_student_cfg.logical_axis_rules = []
11831197

1198+
# main() validates that student/teacher share batch shape — set explicit
1199+
# equal scalars on both mocks so the assertion passes.
1200+
mock_student_cfg.per_device_batch_size = 1
1201+
mock_student_cfg.max_target_length = 16
1202+
mock_student_cfg.gradient_accumulation_steps = 1
1203+
11841204
mock_teacher_cfg = mock.Mock()
11851205
mock_teacher_cfg.vocab_size = 32000
1206+
mock_teacher_cfg.per_device_batch_size = 1
1207+
mock_teacher_cfg.max_target_length = 16
1208+
mock_teacher_cfg.gradient_accumulation_steps = 1
11861209
mock_pyconfig_init.side_effect = [mock_global, mock_student_cfg, mock_teacher_cfg]
11871210

11881211
mock_student_model = mock.Mock()
@@ -1206,7 +1229,8 @@ def test_main_online_mode_loads_teacher(
12061229
self.assertIs(model_bundle.student_model, mock_student_model)
12071230
self.assertIs(model_bundle.teacher_model, mock_teacher_model)
12081231

1209-
def test_student_freeze_param_filter(self):
1232+
@mock.patch.object(distillation_utils, "calculate_distillation_tflops_per_device", return_value=(0.0, 0.0, 0.0))
1233+
def test_student_freeze_param_filter(self, _mock_tflops):
12101234
"""Verifies that student_freeze_param_filter correctly freezes specified parameters."""
12111235

12121236
# 1. Setup a dummy model with multiple layers
@@ -1260,6 +1284,8 @@ def freeze_filter(path):
12601284
strategy=strategy,
12611285
optimizer=optax.sgd(0.1),
12621286
training_config=train_config,
1287+
student_config=mock.Mock(),
1288+
teacher_config=mock.Mock(),
12631289
student_freeze_param_filter=freeze_filter,
12641290
)
12651291
trainer._lora_enabled = False

0 commit comments

Comments
 (0)