From 4f6d05e3a03ca7f6bbfb840c65b106c9de226514 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 18 Mar 2026 21:51:36 +0000 Subject: [PATCH] fix optimizer number of steps --- .../trainers/post_train/distillation/distillation_utils.py | 3 +++ src/maxtext/trainers/post_train/distillation/train_distill.py | 3 +-- tests/post_training/unit/train_distill_test.py | 3 +++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py index 71a63c1ce2..00a2ab4418 100644 --- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py +++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py @@ -248,6 +248,9 @@ def compute_loss( "distill/teacher_loss": teacher_hard_loss, "distill/out_proj_feature_loss": feature_loss, "distill/total_loss": total_loss, + "distill/temperature": self.temperature, + "distill/alpha": self.alpha, + "distill/beta_feature": self.beta_feature, } return total_loss, metrics diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py index 85eb045bfe..7f5541b69a 100644 --- a/src/maxtext/trainers/post_train/distillation/train_distill.py +++ b/src/maxtext/trainers/post_train/distillation/train_distill.py @@ -464,8 +464,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs): ) # 4. Optimizer & Config - total_updates = student_config.steps // student_config.gradient_accumulation_steps - optimizer = get_distillation_optimizer(student_config, total_updates) + optimizer = get_distillation_optimizer(student_config, student_config.steps) checkpointing_options = checkpoint.CheckpointManagerOptions( save_interval_steps=student_config.checkpoint_period, diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py index e059986162..9b4249a8f7 100644 --- a/tests/post_training/unit/train_distill_test.py +++ b/tests/post_training/unit/train_distill_test.py @@ -399,6 +399,9 @@ def _test_monitored_strategy(self, sft_mode: bool): "distill/teacher_loss", "distill/out_proj_feature_loss", "distill/total_loss", + "distill/temperature", + "distill/alpha", + "distill/beta_feature", ] for key in expected_keys: self.assertIn(key, metrics)