diff --git a/src/maxtext/trainers/post_train/distillation/distillation_utils.py b/src/maxtext/trainers/post_train/distillation/distillation_utils.py
index 71a63c1ce2..00a2ab4418 100644
--- a/src/maxtext/trainers/post_train/distillation/distillation_utils.py
+++ b/src/maxtext/trainers/post_train/distillation/distillation_utils.py
@@ -248,6 +248,9 @@ def compute_loss(
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
+        "distill/temperature": self.temperature,
+        "distill/alpha": self.alpha,
+        "distill/beta_feature": self.beta_feature,
     }
     return total_loss, metrics
 
diff --git a/src/maxtext/trainers/post_train/distillation/train_distill.py b/src/maxtext/trainers/post_train/distillation/train_distill.py
index 85eb045bfe..7f5541b69a 100644
--- a/src/maxtext/trainers/post_train/distillation/train_distill.py
+++ b/src/maxtext/trainers/post_train/distillation/train_distill.py
@@ -464,8 +464,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   )
 
   # 4. Optimizer & Config
-  total_updates = student_config.steps // student_config.gradient_accumulation_steps
-  optimizer = get_distillation_optimizer(student_config, total_updates)
+  optimizer = get_distillation_optimizer(student_config, student_config.steps)
 
   checkpointing_options = checkpoint.CheckpointManagerOptions(
       save_interval_steps=student_config.checkpoint_period,
diff --git a/tests/post_training/unit/train_distill_test.py b/tests/post_training/unit/train_distill_test.py
index e059986162..9b4249a8f7 100644
--- a/tests/post_training/unit/train_distill_test.py
+++ b/tests/post_training/unit/train_distill_test.py
@@ -399,6 +399,9 @@ def _test_monitored_strategy(self, sft_mode: bool):
         "distill/teacher_loss",
         "distill/out_proj_feature_loss",
         "distill/total_loss",
+        "distill/temperature",
+        "distill/alpha",
+        "distill/beta_feature",
     ]
     for key in expected_keys:
       self.assertIn(key, metrics)
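
Note for reviewers reading the distillation_utils.py hunk out of context: it adds
the (constant) distillation hyperparameters to the per-step metrics dict so each
run is self-describing in the metrics logs. Below is a minimal, self-contained
Python sketch of that pattern. The class shape, the loss weighting, and the
default values are illustrative assumptions, not the MaxText implementation;
only the metric keys and the self.temperature / self.alpha / self.beta_feature
fields come from the diff itself.

from dataclasses import dataclass


@dataclass
class DistillationLossSketch:
  temperature: float = 2.0    # softmax temperature for soft targets (assumed default)
  alpha: float = 0.5          # weight on the teacher-matching loss (assumed default)
  beta_feature: float = 0.1   # weight on the out_proj feature loss (assumed default)

  def compute_loss(self, student_hard_loss, teacher_hard_loss, soft_loss, feature_loss):
    # Generic distillation combination; the real weighting in MaxText may differ.
    total_loss = (
        (1.0 - self.alpha) * student_hard_loss
        + self.alpha * soft_loss
        + self.beta_feature * feature_loss
    )
    metrics = {
        "distill/teacher_loss": teacher_hard_loss,
        "distill/out_proj_feature_loss": feature_loss,
        "distill/total_loss": total_loss,
        # New in this diff: log the static hyperparameters as scalar metrics
        # so runs can be compared across experiments without digging up configs.
        "distill/temperature": self.temperature,
        "distill/alpha": self.alpha,
        "distill/beta_feature": self.beta_feature,
    }
    return total_loss, metrics

The train_distill.py hunk, by contrast, changes behavior: get_distillation_optimizer
now receives student_config.steps directly rather than steps //
gradient_accumulation_steps, presumably so the optimizer's learning-rate schedule
spans the full step count (e.g. if gradient accumulation is handled outside the
optimizer's own step counter).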