
Commit b2bbf02

add gradient_accumulation to distill pipeline
1 parent: f2d2ec8

2 files changed: 74 additions & 1 deletion
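
In short: when gradient_accumulation_steps is set above 1, the distillation optimizer is wrapped in optax.MultiSteps so that gradients from that many micro-batches are accumulated before each parameter update; the optimizer is sized for steps // gradient_accumulation_steps real updates rather than steps; the setting is threaded through to the trainer configuration; and a new unit test verifies that student weights move only after the k-th forward pass.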

src/maxtext/trainers/post_train/distillation/train_distill.py (5 additions & 1 deletion)

@@ -204,6 +204,8 @@ def __init__(self, model, strategy, optimizer, training_config, **kwargs):
     self.strategy = strategy

     # override optimizer to only use student_model.
+    if training_config.gradient_accumulation_steps is not None and training_config.gradient_accumulation_steps > 1:
+      optimizer = optax.MultiSteps(optimizer, training_config.gradient_accumulation_steps)
     wrt = nnx.LoRAParam if self._lora_enabled else nnx.Param
     self.optimizer = nnx.Optimizer(model.student_model, optimizer, wrt=wrt)

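For context, optax.MultiSteps wraps an inner optimizer so that gradients are accumulated across micro-batches and a real update is emitted only on every k-th call; the intermediate calls return all-zero updates, leaving the parameters untouched. A minimal standalone sketch of that behavior (illustrative only, not part of this commit):

    import jax.numpy as jnp
    import optax

    k = 2
    opt = optax.MultiSteps(optax.sgd(learning_rate=0.1), every_k_schedule=k)
    params = {"w": jnp.ones(2)}
    state = opt.init(params)
    grads = {"w": jnp.ones(2)}

    updates, state = opt.update(grads, state, params)  # micro-step 1: accumulate only
    print(updates["w"])  # [0. 0.], parameters would not move yet
    updates, state = opt.update(grads, state, params)  # micro-step 2: apply
    print(updates["w"])  # [-0.1 -0.1], an sgd step on the mean accumulated gradient
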
@@ -457,7 +459,8 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   )

   # 4. Optimizer & Config
-  optimizer = get_distillation_optimizer(student_config, student_config.steps)
+  total_updates = student_config.steps // student_config.gradient_accumulation_steps
+  optimizer = get_distillation_optimizer(student_config, total_updates)

   checkpointing_options = checkpoint.CheckpointManagerOptions(
       save_interval_steps=student_config.checkpoint_period,
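
The integer division matters because MultiSteps advances the inner optimizer, and with it any learning-rate schedule, once per applied update rather than once per micro-batch. With illustrative numbers (not from this commit): steps=1000 and gradient_accumulation_steps=4 gives 1000 // 4 = 250 real updates, so a decay schedule sized for 1000 steps would only ever traverse its first quarter. One caveat: when steps is not a multiple of gradient_accumulation_steps, the trailing remainder micro-batches accumulate gradients that are never applied.
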
@@ -486,6 +489,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
       profiler_options=profiler_options,
       checkpoint_root_directory=student_config.checkpoint_dir,
       checkpointing_options=checkpointing_options,
+      gradient_accumulation_steps=student_config.gradient_accumulation_steps,
   )

   # 5. Data Iterators (Init BEFORE Trainer)

tests/unit/train_distill_test.py (69 additions & 0 deletions)

@@ -26,6 +26,7 @@
 from unittest import mock
 import jax
 import jax.numpy as jnp
+from flax import nnx
 import numpy as np
 import optax
 import orbax.checkpoint as ocp
@@ -537,6 +538,74 @@ def test_post_process_train_step(self):
     values_list = mock_buffer.additional_metrics["distill/kl_div"][0]
     self.assertEqual(values_list[0], 0.5)

+  def test_gradient_accumulation_requires_k_passes_for_update(self):
+    """Verifies that weights only update after k distinct forward passes."""
+
+    # 1. Set up a minimal NNX model
+    class DummyModel(nnx.Module):
+
+      def __init__(self):
+        self.linear = nnx.Linear(in_features=2, out_features=2, rngs=nnx.Rngs(0))
+
+      def __call__(self, x):
+        return self.linear(x)
+
+    student = DummyModel()
+    teacher = DummyModel()
+    model_bundle = train_distill.ModelBundle(teacher_model=teacher, student_model=student)
+
+    # Snapshot the initial weights
+    initial_weights = student.linear.kernel.value.copy()
+
+    # 2. Set up the optimizer with MultiSteps (accumulate over 2 passes)
+    base_optimizer = optax.sgd(learning_rate=0.1)
+    accumulating_optimizer = optax.MultiSteps(base_optimizer, every_k_schedule=2)
+    nnx_opt = nnx.Optimizer(student, accumulating_optimizer, wrt=nnx.Param)
+
+    # 3. Initialize the trainer and mocks
+    # pylint: disable=no-value-for-parameter
+    trainer = train_distill.MaxTextDistillationTrainer.__new__(train_distill.MaxTextDistillationTrainer)
+    trainer.strategy = mock.Mock()
+
+    dummy_batch = {
+        "input_tokens": jnp.ones((1, 2)),
+        "positions": None,
+        "targets": None,
+        "teacher_output": jnp.array([1.0, 1.0]),
+    }
+    trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch)
+    trainer.strategy.labels_fn.return_value = None
+
+    # 4. Mock the forward pass to COUNT how many times it executes.
+    # We wrap the actual dummy-model execution in a mock to track it.
+    mock_student_forward = mock.Mock(side_effect=lambda model, **kwargs: model(dummy_batch["input_tokens"]))
+    trainer.strategy.student_forward_fn = mock_student_forward
+
+    trainer.strategy.compute_loss.side_effect = lambda s_out, t_out, labels: (jnp.sum(s_out), {"aux": 1.0})
+
+    # --- EXECUTE PASS 1 ---
+    trainer._train_step(model_bundle, nnx_opt, dummy_batch)
+
+    # ASSERTIONS AFTER PASS 1:
+    # Verify exactly ONE forward pass happened
+    self.assertEqual(mock_student_forward.call_count, 1)
+
+    # Verify weights are completely UNCHANGED
+    np.testing.assert_allclose(
+        student.linear.kernel.value, initial_weights, err_msg="Weights should not update on the first pass."
+    )
+
+    # --- EXECUTE PASS 2 ---
+    trainer._train_step(model_bundle, nnx_opt, dummy_batch)
+
+    # ASSERTIONS AFTER PASS 2:
+    # Verify exactly TWO forward passes have now happened
+    self.assertEqual(mock_student_forward.call_count, 2)
+
+    # Verify weights HAVE changed
+    with self.assertRaises(AssertionError, msg="Weights should have updated on the second pass."):
+      np.testing.assert_allclose(student.linear.kernel.value, initial_weights)
+

 if __name__ == "__main__":
   absltest.main()
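
Assuming the repository's usual test tooling (the invocation is not part of the commit), the new test can be run on its own with something like:

    python -m pytest tests/unit/train_distill_test.py -k gradient_accumulation

Wrapping np.testing.assert_allclose in assertRaises(AssertionError) is a compact way of asserting 'not all close' without a custom comparator, at the cost of a slightly indirect failure message.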
