
Commit 3b4244e

Merge pull request #3548 from AI-Hypercomputer:vladk/distill-freeze
PiperOrigin-RevId: 893589562
2 parents: cdc587f + d205e16

3 files changed: 135 additions & 5 deletions


src/maxtext/configs/types.py

Lines changed: 7 additions & 0 deletions
@@ -1146,6 +1146,13 @@ class Distillation(BaseModel):
   distill_beta: float = Field(0.0, description="Weight for the feature loss component. Use 0.0 to disable")
   distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")

+  # --- Distillation freezing filter ---
+  student_params_to_update: None | list = Field(
+      None,
+      description="A list of model param name templates to finetune in the student model. "
+      "The other parameters will be frozen if this attribute is non-empty.",
+  )
+

 class TrainingLoop(BaseModel):
   """Configuration for the main training loop, evaluation, and reproducibility."""

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 35 additions & 5 deletions
@@ -34,7 +34,8 @@
 """

 import inspect
-from typing import Sequence, Callable
+import logging
+from typing import Sequence, Callable, Any
 from absl import app
 from flax import nnx
 from flax.linen import partitioning as nn_partitioning
@@ -199,7 +200,15 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   (positions, segment_ids) are passed to the model.
   """

-  def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
+  def __init__(
+      self,
+      model,
+      strategy: distillation_utils.DistillationStrategy,
+      optimizer,
+      training_config,
+      student_freeze_param_filter: Callable[[Any], bool] | None = None,
+      **kwargs,
+  ):
     # We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
     # allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
     # redefining the trainer optimizer here.
@@ -211,8 +220,22 @@ def __init__(self, model, strategy: distillation_utils.DistillationStrategy, opt
     # override optimizer to only use student_model.
     if training_config.gradient_accumulation_steps is not None and training_config.gradient_accumulation_steps > 1:
       optimizer = optax.MultiSteps(optimizer, training_config.gradient_accumulation_steps)
-    wrt = nnx.LoRAParam if self._lora_enabled else nnx.Param
-    self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=wrt)
+
+    base_wrt = nnx.LoRAParam if getattr(self, "_lora_enabled", False) else nnx.Param
+    if student_freeze_param_filter:
+
+      def wrt_filter(path, x):
+        if not isinstance(x, base_wrt):
+          return False
+        freeze = student_freeze_param_filter(path)
+        logging.info("Student model freezing info: Parameter %s; freeze=%s", path, freeze)
+        return not freeze
+
+      self.wrt_filter = wrt_filter
+    else:
+      self.wrt_filter = base_wrt
+
+    self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=self.wrt_filter)

     # Detect if Tunix expects _train_step to return grad_norm by inspecting the source
     self._tunix_expects_grad_norm = False
@@ -282,7 +305,7 @@ def loss_wrapper(student, teacher, batch):
     # we only compute gradients for the student.
     grad_fn = nnx.value_and_grad(
         loss_wrapper,
-        argnums=0,
+        argnums=nnx.DiffState(0, self.wrt_filter),
        has_aux=True,
     )

@@ -564,6 +587,12 @@ def train_distill(
   _log_config_details(student_config, "Student")
   student_model = get_maxtext_model(student_config, mesh)

+  student_params_to_update = getattr(student_config, "student_params_to_update", [])
+
+  def student_freeze_param_fn(path) -> bool:
+    path_str = "/".join(str(p) for p in path)
+    return not any(template in path_str for template in student_params_to_update)
+
   if is_offline:
     max_logging.log("Offline Distillation: Skipping Teacher Model loading.")
     teacher_model = None
@@ -582,6 +611,7 @@ def train_distill(
       strategy=strategy,
       optimizer=optimizer,
       training_config=train_config,
+      student_freeze_param_filter=student_freeze_param_fn if student_params_to_update else None,
   )
   trainer.is_managed_externally = True
   trainer._has_aux = True  # pylint: disable=protected-access

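The freezing logic above hinges on two Flax NNX hooks that both accept the same callable filter: nnx.Optimizer(..., wrt=filter), so optimizer state is only allocated for trainable parameters, and nnx.value_and_grad(..., argnums=nnx.DiffState(0, filter)), so gradients are only computed for that subset. A standalone sketch of the pattern outside MaxText (the toy module and its attribute names are made up for illustration; optimizer.update is left as a comment since its exact signature varies across Flax releases):

import jax.numpy as jnp
import optax
from flax import nnx


class TwoLayer(nnx.Module):
  """Toy model: one layer we intend to freeze, one we intend to train."""

  def __init__(self, rngs: nnx.Rngs):
    self.frozen = nnx.Linear(in_features=2, out_features=2, rngs=rngs)
    self.trainable = nnx.Linear(in_features=2, out_features=2, rngs=rngs)

  def __call__(self, x):
    return self.trainable(self.frozen(x))


def wrt_filter(path, value):
  # Select only nnx.Param leaves whose path mentions "trainable".
  return isinstance(value, nnx.Param) and any("trainable" in str(p) for p in path)


model = TwoLayer(nnx.Rngs(0))
# Optimizer state is created only for the parameters selected by the filter.
optimizer = nnx.Optimizer(model, optax.sgd(0.1), wrt=wrt_filter)


def loss_fn(m, x):
  return jnp.sum(m(x) ** 2)


# DiffState restricts differentiation of argument 0 to the same filtered subset,
# so `grads` contains no leaves for the "frozen" layer.
grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, wrt_filter))
loss, grads = grad_fn(model, jnp.ones((1, 2)))
# optimizer.update(...) would then apply these filtered gradients to the
# trainable layer only, leaving the frozen layer untouched.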
tests/post_training/unit/train_distill_test.py

Lines changed: 93 additions & 0 deletions
@@ -126,6 +126,7 @@ def test_prepare_inputs_logic(self):
     trainer.teacher_model = mock.Mock()
     trainer.model = mock.Mock()
     trainer.gen_model_input_fn = lambda x: {"inputs": {"some_key": "some_val"}}
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Input
     # pylint: disable=unexpected-keyword-arg
@@ -153,6 +154,7 @@ def test_train_step_skips_teacher_forward_when_output_present(
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITH teacher_output
     mock_batch = {
@@ -205,6 +207,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITHOUT teacher_output
     mock_batch = {
@@ -278,6 +281,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Batch WITH targets_segmentation
     mock_targets_segmentation = jnp.array([[1, 1, 0]])
@@ -579,6 +583,7 @@ def test_eval_step_calls_student_forward(self):
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # 2. Setup Input Mocks
     raw_inputs = mock.Mock()
@@ -675,6 +680,7 @@ def test_post_process_train_step(self):
     """Verifies metrics are moved from aux dict to the trainer buffer."""
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     # Setup MetricsBuffer mock
     mock_buffer = mock.Mock()
@@ -723,6 +729,7 @@ def __call__(self, x):
     # pylint: disable=no-value-for-parameter
     trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
     trainer.strategy = mock.Mock()
+    trainer.wrt_filter = lambda path, x: True  # type: ignore

     dummy_batch = {
         "input_tokens": jnp.ones((1, 2)),
@@ -1121,6 +1128,92 @@ def test_main_online_mode_loads_teacher(
     self.assertIs(model_bundle.student_model, mock_student_model)
     self.assertIs(model_bundle.teacher_model, mock_teacher_model)

+  def test_student_freeze_param_filter(self):
+    """Verifies that student_freeze_param_filter correctly freezes specified parameters."""
+
+    # 1. Setup a dummy model with multiple layers
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.layer1 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0))
+        self.layer2 = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(1))
+
+      def __call__(self, input_tokens, **kwargs):
+        # Apply layers
+        return self.layer2(self.layer1(input_tokens))
+
+    student = DummyModel()
+    teacher = DummyModel()
+    model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)
+
+    # Snapshot initial weights
+    initial_layer1_weights = student.layer1.kernel.get_value().copy()
+    initial_layer2_weights = student.layer2.kernel.get_value().copy()
+
+    # 2. Setup freeze filter (freeze layer1, train layer2)
+    def freeze_filter(path):
+      path_str = "/".join(str(p) for p in path)
+      return "layer1" in path_str
+
+    # 3. Setup Strategy and TrainingConfig
+    strategy = mock.Mock()
+    strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out.logits), {"aux": 1.0})
+    strategy.create_labels.return_value = None
+    strategy.student_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
+        logits=model(kw["input_tokens"])
+    )
+    strategy.teacher_forward_fn = lambda model, **kw: distillation_utils.DistillationForwardOutput(
+        logits=model(kw["input_tokens"])
+    )
+
+    # pylint: disable=import-outside-toplevel
+    from tunix.sft import peft_trainer
+
+    train_config = peft_trainer.TrainingConfig(
+        max_steps=1,
+        eval_every_n_steps=0,
+        # checkpointing_options=ocp.CheckpointManagerOptions(create=False),
+        gradient_accumulation_steps=1,
+    )
+
+    # 4. Initialize Trainer
+    trainer = train_distill.MaxTextDistillationTrainer(
+        model=model_bundle,
+        strategy=strategy,
+        optimizer=optax.sgd(0.1),
+        training_config=train_config,
+        student_freeze_param_filter=freeze_filter,
+    )
+    trainer._lora_enabled = False
+    trainer.is_managed_externally = True
+
+    trainer = trainer.with_gen_model_input_fn(
+        lambda batch: {
+            "input_tokens": batch["input_tokens"],
+            "positions": None,
+            "attention_mask": None,
+            "decoder_segment_ids": None,
+            "targets": None,
+            "teacher_output": distillation_utils.DistillationForwardOutput(logits=jnp.ones((1, 2))),
+        }
+    )
+
+    dummy_batch = {"input_tokens": jnp.ones((1, 2))}
+
+    # 5. Execute Pass
+    trainer._train_step(model_bundle, trainer.optimizer, dummy_batch)
+
+    # 6. Verify layer1 is unchanged (frozen)
+    np.testing.assert_allclose(
+        student.layer1.kernel.get_value(),
+        initial_layer1_weights,
+        err_msg="layer1 weights should be frozen and remain unchanged.",
+    )
+
+    # Verify layer2 has changed (trained)
+    is_layer2_unchanged = np.allclose(student.layer2.kernel.get_value(), initial_layer2_weights)
+    self.assertFalse(is_layer2_unchanged, msg="layer2 weights should have updated.")
+

 if __name__ == "__main__":
   absltest.main()
