Skip to content

Commit 9f6b09a

Browse files
Merge pull request #3478 from AI-Hypercomputer:vladk/distill-eod-refactor
PiperOrigin-RevId: 888846846
2 parents 45e21f8 + ced6b20 commit 9f6b09a

3 files changed

Lines changed: 159 additions & 45 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 97 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
# Copyright 2023-2026 Google LLC
1+
# Copyright 2023-2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# https://www.apache.org/licenses/LICENSE-2.0
88
#
99
# Unless required by applicable law or agreed to in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -22,6 +22,7 @@
2222
import tensorflow as tf
2323
from array_record.python import array_record_module
2424

25+
import abc
2526
from typing import Any, Iterator, Optional, List, Callable
2627

2728
import flax
@@ -182,21 +183,87 @@ def __next__(self) -> MaxTextTrainingInput:
182183
# -----------------------------------------------------------------------------
183184
# Distillation Strategy
184185
# -----------------------------------------------------------------------------
185-
class CombinedDistillationStrategy:
186+
187+
188+
class DistillationStrategy(abc.ABC):
189+
"""Abstract base class for MaxText Distillation Strategies."""
190+
191+
def __init__(
192+
self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
193+
):
194+
"""Initializes the generic distillation strategy.
195+
196+
Args:
197+
student_forward_fn: Function to compute student model outputs.
198+
teacher_forward_fn: Function to compute teacher model outputs.
199+
vocab_size: The size of the model's vocabulary.
200+
pad_id: The ID used for padding tokens.
201+
"""
202+
self.student_forward_fn = student_forward_fn
203+
self.teacher_forward_fn = teacher_forward_fn
204+
self.vocab_size = vocab_size
205+
self.pad_id = pad_id
206+
207+
@abc.abstractmethod
208+
def compute_loss(
209+
self,
210+
student_output: "DistillationForwardOutput",
211+
teacher_output: "DistillationForwardOutput",
212+
labels: jax.Array,
213+
) -> tuple[jax.Array, dict[str, jax.Array]]:
214+
"""Computes the distillation loss.
215+
216+
Args:
217+
student_output: The forward pass output of the student model.
218+
teacher_output: The forward pass output of the frozen teacher model.
219+
labels: The masked one-hot encoded ground truth labels.
220+
221+
Returns:
222+
A tuple containing the scalar loss and a dictionary of auxiliary metrics
223+
(e.g., {"distill/soft_loss": ..., "distill/total_loss": ...})
224+
"""
225+
raise NotImplementedError
226+
227+
@abc.abstractmethod
228+
def compute_eval_loss(
229+
self,
230+
student_output: "DistillationForwardOutput",
231+
labels: jax.Array,
232+
) -> tuple[jax.Array, dict[str, jax.Array]]:
233+
"""Computes the evaluation loss (typically just the task loss).
234+
235+
Args:
236+
student_output: The forward pass output of the student model.
237+
labels: The masked one-hot encoded ground truth labels.
238+
239+
Returns:
240+
A tuple containing the scalar loss and an empty (or auxiliary) dict.
241+
"""
242+
raise NotImplementedError
243+
244+
@abc.abstractmethod
245+
def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
246+
"""
247+
Creates labels tensor to compute the loss
248+
"""
249+
raise NotImplementedError
250+
251+
252+
class CombinedDistillationStrategy(DistillationStrategy):
186253
"""Strategy that returns detailed metrics for TensorBoard."""
187254

188255
def __init__(
189256
self,
190257
student_forward_fn: Callable[..., DistillationForwardOutput],
191258
teacher_forward_fn: Callable[..., DistillationForwardOutput],
192-
labels_fn: Callable[..., jax.Array],
259+
pad_id: int = 0,
193260
temperature: float = 2.0,
194261
alpha: float = 0.5,
195262
beta_feature: float = 0.0,
196263
layer_indices: Optional[List[int]] = None,
197264
feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
198265
cosine_distance_axis: int | tuple[int, ...] = -1,
199-
sft_mode: bool = False,
266+
vocab_size: int = 0,
200267
):
201268
"""Initializes the Combined strategy using tunix logit.LogitStrategy.
202269
@@ -213,9 +280,14 @@ def __init__(
213280
cosine_distance_axis: The axis to use for cosine distance computation if
214281
feature_loss_fn is not provided. Defaults to -1.
215282
"""
216-
self.student_forward_fn = student_forward_fn
217-
self.teacher_forward_fn = teacher_forward_fn
218-
self.labels_fn = labels_fn
283+
284+
super().__init__(
285+
student_forward_fn=student_forward_fn,
286+
teacher_forward_fn=teacher_forward_fn,
287+
vocab_size=vocab_size,
288+
pad_id=pad_id,
289+
)
290+
219291
self.temperature = temperature
220292
self.alpha = alpha
221293
self.beta_feature = beta_feature
@@ -226,7 +298,6 @@ def __init__(
226298
self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
227299
optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
228300
)
229-
self.sft_mode = sft_mode
230301

231302
def compute_loss(
232303
self,
@@ -253,10 +324,9 @@ def compute_loss(
253324

254325
log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
255326
teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
256-
257327
# labels are supposed to have all sft masks applied by this moment
258-
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
259-
mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
328+
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
329+
mean_mask = jnp.squeeze(labels_mask, axis=-1)
260330

261331
# KL(Teacher || Student)
262332
kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
@@ -297,7 +367,7 @@ def compute_loss(
297367
metrics = {
298368
"distill/soft_loss": soft_loss,
299369
"distill/hard_loss": hard_loss,
300-
"distill/kl_div": jnp.mean(kl_div),
370+
"distill/kl_div": jnp.mean(kl_div, where=mean_mask),
301371
"distill/teacher_loss": teacher_hard_loss,
302372
"distill/out_proj_feature_loss": feature_loss,
303373
"distill/total_loss": total_loss,
@@ -316,12 +386,24 @@ def compute_eval_loss(
316386
# Parent logic for task loss
317387
# We re-implement simple CE here to ensure float32 casting
318388
s_logits = student_output.logits.astype(jnp.float32)
319-
ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
320-
task_loss = jnp.mean(ce_loss)
389+
390+
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
391+
mean_mask = jnp.squeeze(labels_mask, axis=-1)
392+
ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
393+
task_loss = jnp.mean(ce_loss, where=mean_mask)
321394

322395
# Must return a tuple because _has_aux=True expects it
323396
return task_loss, {}
324397

398+
def create_labels(self, targets, targets_segmentation=None, **kwargs):
399+
"""Converts integer targets to masked one-hot vectors for hard label loss."""
400+
del kwargs # Unused
401+
one_hot = jax.nn.one_hot(targets, self.vocab_size)
402+
mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
403+
if targets_segmentation is not None:
404+
mask = mask * (targets_segmentation != 0)[..., None]
405+
return one_hot * mask
406+
325407

326408
# -----------------------------------------------------------------------------
327409
# Checkpoint Manager

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
# Copyright 2023-2026 Google LLC
1+
# Copyright 2023-2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
55
# You may obtain a copy of the License at
66
#
7-
# http://www.apache.org/licenses/LICENSE-2.0
7+
# https://www.apache.org/licenses/LICENSE-2.0
88
#
99
# Unless required by applicable law or agreed to in writing, software
1010
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -199,7 +199,7 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
199199
(positions, segment_ids) are passed to the model.
200200
"""
201201

202-
def __init__(self, model, strategy, optimizer, training_config, **kwargs):
202+
def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
203203
# We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
204204
# allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
205205
# redefining the trainer optimizer here.
@@ -275,7 +275,7 @@ def loss_wrapper(student, teacher, batch):
275275
cache=None,
276276
)
277277
# we should apply a mask for labels to disable segment-separator tokens
278-
labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
278+
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
279279
return self.strategy.compute_loss(student_output, teacher_output, labels)
280280

281281
# Because student is the 0th argument, argnums=0 guarantees
@@ -308,7 +308,7 @@ def _eval_step(self, model, inputs):
308308
decoder_segment_ids=inputs.get("decoder_segment_ids"),
309309
cache=None,
310310
)
311-
labels = self.strategy.labels_fn(inputs["targets"])
311+
labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None))
312312
return self.strategy.compute_eval_loss(student_output, labels)
313313

314314
def _prepare_inputs(
@@ -470,14 +470,6 @@ def build_training_components(
470470
pad_id = tok.pad_id if tok.pad_id is not None else 0
471471

472472
# 3. Define Distillation Strategy
473-
def labels_fn(targets, targets_segmentation=None, **kwargs):
474-
"""Converts integer targets to masked one-hot vectors for hard label loss."""
475-
del kwargs # Unused
476-
one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
477-
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
478-
if targets_segmentation is not None:
479-
mask = mask * (targets_segmentation != 0)[..., None]
480-
return one_hot * mask
481473

482474
# Both Student and Teacher use the same forward logic via the adapter
483475
student_forward_fn = create_forward_fn(student_config)
@@ -487,12 +479,12 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
487479
strategy = distillation_utils.CombinedDistillationStrategy(
488480
student_forward_fn=student_forward_fn,
489481
teacher_forward_fn=teacher_forward_fn,
490-
labels_fn=labels_fn,
482+
pad_id=pad_id,
491483
temperature=student_config.distill_temperature,
492484
alpha=student_config.distill_alpha,
493485
beta_feature=student_config.distill_beta,
494486
layer_indices=student_config.distill_layer_indices,
495-
sft_mode=student_config.use_sft,
487+
vocab_size=student_config.vocab_size,
496488
)
497489

498490
# 4. Optimizer & Config

tests/post_training/unit/train_distill_test.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(
258258
)
259259

260260
# Verify loss computation and optimizer update
261-
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
261+
trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
262262
trainer.strategy.compute_loss.assert_called_once()
263263
optimizer.update.assert_called_once_with(student_model, mock_grads)
264264

@@ -307,7 +307,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
307307
loss_wrapper(student_model, teacher_model, mock_batch)
308308

309309
# 6. Assertions
310-
trainer.strategy.labels_fn.assert_called_once_with(
310+
trainer.strategy.create_labels.assert_called_once_with(
311311
mock_batch["targets"], targets_segmentation=mock_targets_segmentation
312312
)
313313
trainer.strategy.student_forward_fn.assert_called_once_with(
@@ -378,12 +378,11 @@ def _test_monitored_strategy(self, sft_mode: bool):
378378
strategy = distillation_utils.CombinedDistillationStrategy(
379379
student_forward_fn=lambda m, **k: None,
380380
teacher_forward_fn=lambda m, **k: None,
381-
labels_fn=lambda t: t,
381+
vocab_size=4,
382382
temperature=1.0,
383383
alpha=0.5,
384384
beta_feature=1.0,
385385
layer_indices=None,
386-
sft_mode=sft_mode,
387386
)
388387

389388
# Dummy inputs (batch=1, seq=2, vocab=4)
@@ -426,18 +425,15 @@ def _test_monitored_strategy(self, sft_mode: bool):
426425
self.assertLess(metrics["distill/kl_div"], 1e-5)
427426
self.assertLess(metrics["distill/out_proj_feature_loss"], 1e-5)
428427

429-
def test_strategy_compute_eval_loss(self):
430-
self._verify_strategy_compute_eval_loss(sft_mode=False)
431-
432-
def _verify_strategy_compute_eval_loss(self, sft_mode):
428+
def verify_strategy_compute_eval_loss(self):
433429
"""Covers MonitoredLogitStrategy.compute_eval_loss."""
434430
strategy = distillation_utils.CombinedDistillationStrategy(
435431
student_forward_fn=mock.Mock(),
436432
teacher_forward_fn=mock.Mock(),
437-
labels_fn=mock.Mock(),
433+
vocab_size=4,
434+
# student_config=mock_config,
438435
temperature=1.0,
439436
alpha=0.5,
440-
sft_mode=sft_mode,
441437
)
442438
# Case where feature loss is enabled
443439
logits = distillation_utils.DistillationForwardOutput(
@@ -459,8 +455,51 @@ def _verify_strategy_compute_eval_loss(self, sft_mode):
459455
self.assertTrue(isinstance(loss, jax.Array))
460456
self.assertEqual(aux, {})
461457

462-
def test_strategy_compute_eval_loss_sft(self):
463-
self._verify_strategy_compute_eval_loss(sft_mode=True)
458+
def test_strategy_ignores_segmentation_zero_tokens(self):
459+
"""Verifies that 0 tokens in targets_segmentation are ignored in loss computation."""
460+
strategy = distillation_utils.CombinedDistillationStrategy(
461+
student_forward_fn=mock.Mock(),
462+
teacher_forward_fn=mock.Mock(),
463+
vocab_size=4,
464+
temperature=1.0,
465+
alpha=0.5,
466+
pad_id=0,
467+
)
468+
469+
# 1. Leverage the targets_segmentation tensor and put a 0 token in between.
470+
# Token 1 is a delimiter (targets_segmentation = 0).
471+
targets = jnp.array([[2, 1, 3]])
472+
targets_segmentation = jnp.array([[1, 0, 1]])
473+
474+
# 2. Create labels with the zeroed out segment delimiter mask.
475+
labels = strategy.create_labels(targets, targets_segmentation=targets_segmentation)
476+
477+
# Student has all predictions incorrect
478+
s_logits = jnp.array(
479+
[
480+
[
481+
[10.0, -10.0, -10.0, -10.0],
482+
[-10.0, 10.0, -10.0, -10.0],
483+
[-10.0, 10.0, -10.0, -10.0],
484+
]
485+
] # correct
486+
)
487+
student_output = distillation_utils.DistillationForwardOutput(logits=s_logits, out_projection_activations=None)
488+
489+
# Teacher perfectly predicts the target for Token 0 and Token 2, and class 1 for Token 1
490+
t_logits = jnp.array([[[-10.0, -10.0, 10.0, -10.0], [10.0, -10.0, -10.0, -10.0], [-10.0, -10.0, -10.0, 10.0]]])
491+
teacher_output = distillation_utils.DistillationForwardOutput(logits=t_logits, out_projection_activations=None)
492+
493+
# 3. Call compute_loss()
494+
_, metrics = strategy.compute_loss(student_output, teacher_output, labels)
495+
496+
# all tokens are predicted incorrect so the loss should be 10*2 since
497+
# token at position 1 should be excluded from the loss
498+
# mean kl_div should also be equal to 20
499+
self.assertTrue(19.0 < metrics["distill/hard_loss"] < 21.0)
500+
self.assertTrue(19.0 < metrics["distill/soft_loss"] < 21.0)
501+
self.assertTrue(19.0 < metrics["distill/kl_div"] < 21.0)
502+
self.assertTrue(metrics["distill/teacher_loss"] == 0.0)
464503

465504
def test_setup_pipeline_grain_enabled(self):
466505
"""Covers setup_checkpoint_manager_and_restore when Grain IS detected."""
@@ -549,6 +588,7 @@ def test_eval_step_calls_student_forward(self):
549588
"attention_mask": mock.Mock(),
550589
"decoder_segment_ids": mock.Mock(),
551590
"targets": mock.Mock(),
591+
"targets_segmentation": None,
552592
}
553593
trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch)
554594

@@ -562,7 +602,7 @@ def test_eval_step_calls_student_forward(self):
562602
trainer.strategy.student_forward_fn.return_value = mock_student_output
563603

564604
mock_labels = mock.Mock()
565-
trainer.strategy.labels_fn.return_value = mock_labels
605+
trainer.strategy.create_labels.return_value = mock_labels
566606

567607
mock_loss = mock.Mock()
568608
trainer.strategy.compute_eval_loss.return_value = mock_loss
@@ -591,7 +631,7 @@ def test_eval_step_calls_student_forward(self):
591631
trainer.strategy.teacher_forward_fn.assert_not_called()
592632

593633
# Verify loss computation pipeline
594-
trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
634+
trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
595635
trainer.strategy.compute_eval_loss.assert_called_once_with(mock_student_output, mock_labels)
596636

597637
# Verify it returns the correct loss
@@ -691,7 +731,7 @@ def __call__(self, x):
691731
"teacher_output": jnp.array([1.0, 1.0]),
692732
}
693733
trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch)
694-
trainer.strategy.labels_fn.return_value = None
734+
trainer.strategy.create_labels.return_value = None
695735

696736
# 4. Mock the forward pass to COUNT how many times it executes
697737
# We wrap the actual dummy model execution in a mock to track it.

0 commit comments

Comments
 (0)