
Commit ced6b20

Fix eod masking + strategy refactoring
1 parent 9e786c8 commit ced6b20

3 files changed: 159 additions & 45 deletions


src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 97 additions & 15 deletions
@@ -1,10 +1,10 @@
-# Copyright 2023-2026 Google LLC
+# Copyright 20232026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -22,6 +22,7 @@
 import tensorflow as tf
 from array_record.python import array_record_module

+import abc
 from typing import Any, Iterator, Optional, List, Callable

 import flax
@@ -182,21 +183,87 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class CombinedDistillationStrategy:
+
+
+class DistillationStrategy(abc.ABC):
+  """Abstract base class for MaxText Distillation Strategies."""
+
+  def __init__(
+      self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
+  ):
+    """Initializes the generic distillation strategy.
+
+    Args:
+      student_forward_fn: Function to compute student model outputs.
+      teacher_forward_fn: Function to compute teacher model outputs.
+      vocab_size: The size of the model's vocabulary.
+      pad_id: The ID used for padding tokens.
+    """
+    self.student_forward_fn = student_forward_fn
+    self.teacher_forward_fn = teacher_forward_fn
+    self.vocab_size = vocab_size
+    self.pad_id = pad_id
+
+  @abc.abstractmethod
+  def compute_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      teacher_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the distillation loss.
+
+    Args:
+      student_output: The forward pass output of the student model.
+      teacher_output: The forward pass output of the frozen teacher model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and a dictionary of auxiliary metrics
+      (e.g., {"distill/soft_loss": ..., "distill/total_loss": ...})
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def compute_eval_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the evaluation loss (typically just the task loss).
+
+    Args:
+      student_output: The forward pass output of the student model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and an empty (or auxiliary) dict.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
+    """
+    Creates labels tensor to compute the loss
+    """
+    raise NotImplementedError
+
+
+class CombinedDistillationStrategy(DistillationStrategy):
   """Strategy that returns detailed metrics for TensorBoard."""

   def __init__(
       self,
       student_forward_fn: Callable[..., DistillationForwardOutput],
       teacher_forward_fn: Callable[..., DistillationForwardOutput],
-      labels_fn: Callable[..., jax.Array],
+      pad_id: int = 0,
       temperature: float = 2.0,
       alpha: float = 0.5,
       beta_feature: float = 0.0,
       layer_indices: Optional[List[int]] = None,
       feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
       cosine_distance_axis: int | tuple[int, ...] = -1,
-      sft_mode: bool = False,
+      vocab_size: int = 0,
   ):
     """Initializes the Combined strategy using tunix logit.LogitStrategy.

@@ -213,9 +280,14 @@ def __init__(
       cosine_distance_axis: The axis to use for cosine distance computation if
         feature_loss_fn is not provided. Defaults to -1.
     """
-    self.student_forward_fn = student_forward_fn
-    self.teacher_forward_fn = teacher_forward_fn
-    self.labels_fn = labels_fn
+
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        vocab_size=vocab_size,
+        pad_id=pad_id,
+    )
+
     self.temperature = temperature
     self.alpha = alpha
     self.beta_feature = beta_feature
@@ -226,7 +298,6 @@ def __init__(
       self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
          optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
       )
-    self.sft_mode = sft_mode

   def compute_loss(
       self,
@@ -253,10 +324,9 @@ def compute_loss(

     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
-
     # labels are supposed to have all sft masks applied by this moment
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
-    mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)

     # KL(Teacher || Student)
     kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
@@ -297,7 +367,7 @@ def compute_loss(
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
-        "distill/kl_div": jnp.mean(kl_div),
+        "distill/kl_div": jnp.mean(kl_div, where=mean_mask),
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
@@ -316,12 +386,24 @@ def compute_eval_loss(
     # Parent logic for task loss
     # We re-implement simple CE here to ensure float32 casting
     s_logits = student_output.logits.astype(jnp.float32)
-    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
-    task_loss = jnp.mean(ce_loss)
+
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)
+    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
+    task_loss = jnp.mean(ce_loss, where=mean_mask)

     # Must return a tuple because _has_aux=True expects it
     return task_loss, {}

+  def create_labels(self, targets, targets_segmentation=None, **kwargs):
+    """Converts integer targets to masked one-hot vectors for hard label loss."""
+    del kwargs  # Unused
+    one_hot = jax.nn.one_hot(targets, self.vocab_size)
+    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
+    if targets_segmentation is not None:
+      mask = mask * (targets_segmentation != 0)[..., None]
+    return one_hot * mask
+

 # -----------------------------------------------------------------------------
 # Checkpoint Manager
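
The EOD-masking fix above follows one convention: create_labels zeroes out the one-hot rows for pad and segment-separator tokens, and compute_loss / compute_eval_loss recover that mask with jnp.any(labels != 0, ...) and pass it to the optax losses and jnp.mean through where=. A minimal standalone sketch of that convention (toy values and hypothetical variable names, not MaxText code):

# Sketch: how masked one-hot labels exclude pad/EOD positions from the loss.
import jax
import jax.numpy as jnp
import optax

vocab_size, pad_id = 4, 0
targets = jnp.array([[2, 1, 3]])               # position 1 is a segment separator
targets_segmentation = jnp.array([[1, 0, 1]])  # 0 marks the separator slot

# Same masking as CombinedDistillationStrategy.create_labels in the diff above.
one_hot = jax.nn.one_hot(targets, vocab_size)
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
mask = mask * (targets_segmentation != 0)[..., None]
labels = one_hot * mask                        # masked rows become all zeros

# Masks recovered from the labels, as in compute_loss / compute_eval_loss.
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)  # (batch, seq, 1)
mean_mask = jnp.squeeze(labels_mask, axis=-1)               # (batch, seq)

logits = jnp.zeros((1, 3, vocab_size))         # stand-in student logits
ce = optax.softmax_cross_entropy(logits=logits, labels=labels, where=labels_mask)
loss = jnp.mean(ce, where=mean_mask)           # separator position excluded from the mean
print(loss)                                    # averages over the 2 unmasked positions only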

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 7 additions & 15 deletions
@@ -1,10 +1,10 @@
-# Copyright 2023-2026 Google LLC
+# Copyright 20232026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
@@ -197,7 +197,7 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   (positions, segment_ids) are passed to the model.
   """

-  def __init__(self, model, strategy, optimizer, training_config, **kwargs):
+  def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
     # We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
     # allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
     # redefining the trainer optimizer here.
@@ -245,7 +245,7 @@ def loss_wrapper(student, teacher, batch):
          cache=None,
       )
       # we should apply a mask for labels to disable segment-separator tokens
-      labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
+      labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
       return self.strategy.compute_loss(student_output, teacher_output, labels)

     # Because student is the 0th argument, argnums=0 guarantees
@@ -274,7 +274,7 @@ def _eval_step(self, model, inputs):
        decoder_segment_ids=inputs.get("decoder_segment_ids"),
        cache=None,
     )
-    labels = self.strategy.labels_fn(inputs["targets"])
+    labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None))
     return self.strategy.compute_eval_loss(student_output, labels)

   def _prepare_inputs(
@@ -454,14 +454,6 @@ def train_distill(
   teacher_model.eval()

   # 3. Define Distillation Strategy
-  def labels_fn(targets, targets_segmentation=None, **kwargs):
-    """Converts integer targets to masked one-hot vectors for hard label loss."""
-    del kwargs  # Unused
-    one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
-    mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
-    if targets_segmentation is not None:
-      mask = mask * (targets_segmentation != 0)[..., None]
-    return one_hot * mask

   # Both Student and Teacher use the same forward logic via the adapter
   student_forward_fn = create_forward_fn(student_config)
@@ -471,12 +463,12 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   strategy = distillation_utils.CombinedDistillationStrategy(
       student_forward_fn=student_forward_fn,
       teacher_forward_fn=teacher_forward_fn,
-      labels_fn=labels_fn,
+      pad_id=pad_id,
       temperature=student_config.distill_temperature,
       alpha=student_config.distill_alpha,
       beta_feature=student_config.distill_beta,
       layer_indices=student_config.distill_layer_indices,
-      sft_mode=student_config.use_sft,
+      vocab_size=student_config.vocab_size,
   )

   # 4. Optimizer & Config
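
Because the trainer now depends only on the DistillationStrategy interface (see the strategy type annotation in __init__ above), alternative distillation losses can be added by subclassing instead of threading extra callables such as labels_fn through train_distill. A hedged sketch of what such a subclass could look like; the class name is hypothetical, the import path is assumed from the file layout shown in this commit, and none of this is part of the commit itself:

# Hypothetical subclass (not in this commit): keeps the masked one-hot label
# convention but trains on plain cross-entropy, ignoring the teacher logits.
import jax
import jax.numpy as jnp
import optax

# Import path assumed from the file layout above.
from maxtext.trainers.post_train.distillation import distillation_utils


class HardLabelOnlyStrategy(distillation_utils.DistillationStrategy):
  """Illustrative strategy: masked cross-entropy on hard labels only."""

  def create_labels(self, targets, targets_segmentation=None, **kwargs):
    del kwargs  # Unused
    one_hot = jax.nn.one_hot(targets, self.vocab_size)
    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
    if targets_segmentation is not None:
      mask = mask * (targets_segmentation != 0)[..., None]
    return one_hot * mask

  def compute_loss(self, student_output, teacher_output, labels):
    del teacher_output  # This strategy does not use the teacher.
    loss, _ = self.compute_eval_loss(student_output, labels)
    return loss, {"distill/total_loss": loss}

  def compute_eval_loss(self, student_output, labels):
    s_logits = student_output.logits.astype(jnp.float32)
    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
    ce = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
    return jnp.mean(ce, where=jnp.squeeze(labels_mask, axis=-1)), {}

Such a strategy could then be passed to MaxTextDistillationTrainer wherever CombinedDistillationStrategy is constructed above, since the trainer only calls create_labels, compute_loss and compute_eval_loss.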

tests/post_training/unit/train_distill_test.py

Lines changed: 55 additions & 15 deletions
@@ -247,7 +247,7 @@ def test_train_step_calls_teacher_forward_when_output_missing(self, mock_value_a
     )

     # Verify loss computation and optimizer update
-    trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
+    trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
     trainer.strategy.compute_loss.assert_called_once()
     optimizer.update.assert_called_once_with(student_model, mock_grads)

@@ -291,7 +291,7 @@ def test_train_step_passes_targets_segmentation(self, mock_value_and_grad, mock_
     loss_wrapper(student_model, teacher_model, mock_batch)

     # 6. Assertions
-    trainer.strategy.labels_fn.assert_called_once_with(
+    trainer.strategy.create_labels.assert_called_once_with(
         mock_batch["targets"], targets_segmentation=mock_targets_segmentation
     )
     trainer.strategy.student_forward_fn.assert_called_once_with(
@@ -362,12 +362,11 @@ def _test_monitored_strategy(self, sft_mode: bool):
     strategy = distillation_utils.CombinedDistillationStrategy(
         student_forward_fn=lambda m, **k: None,
         teacher_forward_fn=lambda m, **k: None,
-        labels_fn=lambda t: t,
+        vocab_size=4,
         temperature=1.0,
         alpha=0.5,
         beta_feature=1.0,
         layer_indices=None,
-        sft_mode=sft_mode,
     )

     # Dummy inputs (batch=1, seq=2, vocab=4)
@@ -410,18 +409,15 @@ def _test_monitored_strategy(self, sft_mode: bool):
     self.assertLess(metrics["distill/kl_div"], 1e-5)
     self.assertLess(metrics["distill/out_proj_feature_loss"], 1e-5)

-  def test_strategy_compute_eval_loss(self):
-    self._verify_strategy_compute_eval_loss(sft_mode=False)
-
-  def _verify_strategy_compute_eval_loss(self, sft_mode):
+  def verify_strategy_compute_eval_loss(self):
     """Covers MonitoredLogitStrategy.compute_eval_loss."""
     strategy = distillation_utils.CombinedDistillationStrategy(
         student_forward_fn=mock.Mock(),
         teacher_forward_fn=mock.Mock(),
-        labels_fn=mock.Mock(),
+        vocab_size=4,
+        # student_config=mock_config,
         temperature=1.0,
         alpha=0.5,
-        sft_mode=sft_mode,
     )
     # Case where feature loss is enabled
     logits = distillation_utils.DistillationForwardOutput(
@@ -443,8 +439,51 @@ def _verify_strategy_compute_eval_loss(self, sft_mode):
     self.assertTrue(isinstance(loss, jax.Array))
     self.assertEqual(aux, {})

-  def test_strategy_compute_eval_loss_sft(self):
-    self._verify_strategy_compute_eval_loss(sft_mode=True)
+  def test_strategy_ignores_segmentation_zero_tokens(self):
+    """Verifies that 0 tokens in targets_segmentation are ignored in loss computation."""
+    strategy = distillation_utils.CombinedDistillationStrategy(
+        student_forward_fn=mock.Mock(),
+        teacher_forward_fn=mock.Mock(),
+        vocab_size=4,
+        temperature=1.0,
+        alpha=0.5,
+        pad_id=0,
+    )
+
+    # 1. Leverage the targets_segmentation tensor and put a 0 token in between.
+    # Token 1 is a delimiter (targets_segmentation = 0).
+    targets = jnp.array([[2, 1, 3]])
+    targets_segmentation = jnp.array([[1, 0, 1]])
+
+    # 2. Create labels with the zeroed out segment delimiter mask.
+    labels = strategy.create_labels(targets, targets_segmentation=targets_segmentation)
+
+    # Student has all predictions incorrect
+    s_logits = jnp.array(
+        [
+            [
+                [10.0, -10.0, -10.0, -10.0],
+                [-10.0, 10.0, -10.0, -10.0],
+                [-10.0, 10.0, -10.0, -10.0],
+            ]
+        ]  # correct
+    )
+    student_output = distillation_utils.DistillationForwardOutput(logits=s_logits, out_projection_activations=None)
+
+    # Teacher perfectly predicts the target for Token 0 and Token 2, and class 1 for Token 1
+    t_logits = jnp.array([[[-10.0, -10.0, 10.0, -10.0], [10.0, -10.0, -10.0, -10.0], [-10.0, -10.0, -10.0, 10.0]]])
+    teacher_output = distillation_utils.DistillationForwardOutput(logits=t_logits, out_projection_activations=None)
+
+    # 3. Call compute_loss()
+    _, metrics = strategy.compute_loss(student_output, teacher_output, labels)
+
+    # all tokens are predicted incorrect so the loss should be 10*2 since
+    # token at position 1 should be excluded from the loss
+    # mean kl_div should also be equal to 20
+    self.assertTrue(19.0 < metrics["distill/hard_loss"] < 21.0)
+    self.assertTrue(19.0 < metrics["distill/soft_loss"] < 21.0)
+    self.assertTrue(19.0 < metrics["distill/kl_div"] < 21.0)
+    self.assertTrue(metrics["distill/teacher_loss"] == 0.0)

   def test_setup_pipeline_grain_enabled(self):
     """Covers _setup_and_restore_input_pipeline when Grain IS detected."""
@@ -515,6 +554,7 @@ def test_eval_step_calls_student_forward(self):
         "attention_mask": mock.Mock(),
         "decoder_segment_ids": mock.Mock(),
         "targets": mock.Mock(),
+        "targets_segmentation": None,
     }
     trainer.gen_model_input_fn = mock.Mock(return_value=mock_batch)

@@ -528,7 +568,7 @@ def test_eval_step_calls_student_forward(self):
     trainer.strategy.student_forward_fn.return_value = mock_student_output

     mock_labels = mock.Mock()
-    trainer.strategy.labels_fn.return_value = mock_labels
+    trainer.strategy.create_labels.return_value = mock_labels

     mock_loss = mock.Mock()
     trainer.strategy.compute_eval_loss.return_value = mock_loss
@@ -557,7 +597,7 @@ def test_eval_step_calls_student_forward(self):
     trainer.strategy.teacher_forward_fn.assert_not_called()

     # Verify loss computation pipeline
-    trainer.strategy.labels_fn.assert_called_once_with(mock_batch["targets"])
+    trainer.strategy.create_labels.assert_called_once_with(mock_batch["targets"], targets_segmentation=None)
     trainer.strategy.compute_eval_loss.assert_called_once_with(mock_student_output, mock_labels)

     # Verify it returns the correct loss
@@ -643,7 +683,7 @@ def __call__(self, x):
         "teacher_output": jnp.array([1.0, 1.0]),
     }
     trainer.gen_model_input_fn = mock.Mock(return_value=dummy_batch)
-    trainer.strategy.labels_fn.return_value = None
+    trainer.strategy.create_labels.return_value = None

     # 4. Mock the forward pass to COUNT how many times it executes
     # We wrap the actual dummy model execution in a mock to track it.
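
A quick numeric check of the 19.0 to 21.0 bounds asserted in the new segmentation test (illustrative arithmetic, not part of the test file): with a plus/minus 10 logit margin, softmax puts essentially all mass on the argmax class, so the cross-entropy of a wrong prediction is roughly the 20-point logit gap, while a correct prediction rounds to 0.0 in float32.

# Sanity check for the asserted bounds (not part of the test file).
import jax
import jax.numpy as jnp

wrong = jnp.array([-10.0, 10.0, -10.0, -10.0])  # student logits at a mispredicted position
print(-jax.nn.log_softmax(wrong)[3])            # ~20.0; both unmasked tokens are wrong, so their mean stays near 20

right = jnp.array([-10.0, -10.0, 10.0, -10.0])  # teacher logits at a correctly predicted position
print(-jax.nn.log_softmax(right)[2])            # ~0.0 in float32, hence distill/teacher_loss == 0.0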
