Commit ce6fce8

fixed eod masking + refactored distill strategy
1 parent 9e786c8 commit ce6fce8

3 files changed

Lines changed: 176 additions & 41 deletions

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 98 additions & 13 deletions
@@ -18,10 +18,12 @@
 model structures with Tunix's training interfaces.
 """
 
+import abc
+
 import pickle
 import tensorflow as tf
 from array_record.python import array_record_module
 
 from typing import Any, Iterator, Optional, List, Callable
 
 import flax
@@ -182,21 +186,87 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class CombinedDistillationStrategy:
+
+
+class DistillationStrategy(abc.ABC):
+  """Abstract base class for MaxText distillation strategies."""
+
+  def __init__(
+      self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
+  ):
+    """Initializes the generic distillation strategy.
+
+    Args:
+      student_forward_fn: Function to compute student model outputs.
+      teacher_forward_fn: Function to compute teacher model outputs.
+      vocab_size: The size of the model's vocabulary.
+      pad_id: The ID used for padding tokens.
+    """
+    self.student_forward_fn = student_forward_fn
+    self.teacher_forward_fn = teacher_forward_fn
+    self.vocab_size = vocab_size
+    self.pad_id = pad_id
+
+  @abc.abstractmethod
+  def compute_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      teacher_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the distillation loss.
+
+    Args:
+      student_output: The forward-pass output of the student model.
+      teacher_output: The forward-pass output of the frozen teacher model.
+      labels: The masked one-hot encoded ground-truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and a dictionary of auxiliary metrics
+      (e.g., {"distill/soft_loss": ..., "distill/total_loss": ...}).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def compute_eval_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the evaluation loss (typically just the task loss).
+
+    Args:
+      student_output: The forward-pass output of the student model.
+      labels: The masked one-hot encoded ground-truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and an empty (or auxiliary) dict.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
+    """Creates the labels tensor used to compute the loss."""
+    raise NotImplementedError
+
+
+class CombinedDistillationStrategy(DistillationStrategy):
   """Strategy that returns detailed metrics for TensorBoard."""
 
   def __init__(
       self,
       student_forward_fn: Callable[..., DistillationForwardOutput],
       teacher_forward_fn: Callable[..., DistillationForwardOutput],
-      labels_fn: Callable[..., jax.Array],
+      pad_id: int = 0,
       temperature: float = 2.0,
       alpha: float = 0.5,
       beta_feature: float = 0.0,
       layer_indices: Optional[List[int]] = None,
       feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
       cosine_distance_axis: int | tuple[int, ...] = -1,
-      sft_mode: bool = False,
+      vocab_size: int = 0,
   ):
     """Initializes the Combined strategy using tunix logit.LogitStrategy.
 
@@ -213,9 +283,14 @@ def __init__(
       cosine_distance_axis: The axis to use for cosine distance computation if
         feature_loss_fn is not provided. Defaults to -1.
     """
-    self.student_forward_fn = student_forward_fn
-    self.teacher_forward_fn = teacher_forward_fn
-    self.labels_fn = labels_fn
+
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        vocab_size=vocab_size,
+        pad_id=pad_id,
+    )
+
     self.temperature = temperature
     self.alpha = alpha
     self.beta_feature = beta_feature
@@ -226,7 +301,6 @@ def __init__(
       self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
          optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
       )
-    self.sft_mode = sft_mode
 
   def compute_loss(
       self,
@@ -253,10 +327,9 @@ def compute_loss(
 
     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
-
     # labels are expected to have all SFT masks applied by this point
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
-    mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)
 
     # KL(Teacher || Student)
     kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
@@ -297,7 +370,7 @@ def compute_loss(
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
-        "distill/kl_div": jnp.mean(kl_div),
+        "distill/kl_div": jnp.mean(kl_div, where=mean_mask),
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
@@ -316,12 +389,24 @@ def compute_eval_loss(
     # Parent logic for task loss
     # We re-implement simple CE here to ensure float32 casting
     s_logits = student_output.logits.astype(jnp.float32)
-    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
-    task_loss = jnp.mean(ce_loss)
+
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)
+    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
+    task_loss = jnp.mean(ce_loss, where=mean_mask)
 
     # Must return a tuple because _has_aux=True expects it
     return task_loss, {}
 
+  def create_labels(self, targets, targets_segmentation=None, **kwargs):
+    """Converts integer targets to masked one-hot vectors for hard label loss."""
+    del kwargs  # Unused
+    one_hot = jax.nn.one_hot(targets, self.vocab_size)
+    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
+    if targets_segmentation is not None:
+      mask = mask * (targets_segmentation != 0)[..., None]
+    return one_hot * mask
+
 
 # -----------------------------------------------------------------------------
 # Checkpoint Manager
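
Aside: a minimal, self-contained sketch of the EOD/padding masking this commit fixes, assuming a recent optax whose losses accept a where argument (the diff itself relies on this). The shapes and values below are illustrative, not from the repository:

import jax
import jax.numpy as jnp
import optax

vocab_size, pad_id = 8, 0
targets = jnp.array([[5, 3, 0, 7]])        # position 2 is padding
segmentation = jnp.array([[1, 1, 0, 2]])   # 0 marks an EOD/separator slot

# create_labels: one-hot targets with pad and separator positions zeroed out
one_hot = jax.nn.one_hot(targets, vocab_size)
mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
mask = mask * (segmentation != 0)[..., None]
labels = one_hot * mask                    # masked positions become all-zero rows

# compute_loss / compute_eval_loss recover the masks from the labels alone
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)  # (batch, seq, 1)
mean_mask = jnp.squeeze(labels_mask, axis=-1)               # (batch, seq)

logits = jax.random.normal(jax.random.PRNGKey(0), (1, 4, vocab_size))
ce = optax.softmax_cross_entropy(logits=logits, labels=labels, where=labels_mask)
print(jnp.mean(ce, where=mean_mask))       # averages over real tokens only

Before this commit, jnp.mean(kl_div) also averaged over padded positions, diluting the KL term whenever a batch contained padding; jnp.mean(kl_div, where=mean_mask) restricts the mean to real tokens.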

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 19 additions & 13 deletions
@@ -197,7 +197,7 @@ class MaxTextDistillationTrainer(peft_trainer.PeftTrainer):
   (positions, segment_ids) are passed to the model.
   """
 
-  def __init__(self, model, strategy, optimizer, training_config, **kwargs):
+  def __init__(self, model, strategy: distillation_utils.DistillationStrategy, optimizer, training_config, **kwargs):
     # We pass a dummy optimizer to the base PeftTrainer temporarily to prevent PeftTrainer from eagerly
     # allocating massive optimizer states for the entire ModelBundle (including the frozen teacher) before
     # redefining the trainer optimizer here.
@@ -245,7 +245,7 @@ def loss_wrapper(student, teacher, batch):
           cache=None,
       )
       # Apply a mask to the labels to disable segment-separator tokens
-      labels = self.strategy.labels_fn(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
+      labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
       return self.strategy.compute_loss(student_output, teacher_output, labels)
 
     # Because student is the 0th argument, argnums=0 guarantees
@@ -274,7 +274,7 @@ def _eval_step(self, model, inputs):
         decoder_segment_ids=inputs.get("decoder_segment_ids"),
         cache=None,
     )
-    labels = self.strategy.labels_fn(inputs["targets"])
+    labels = self.strategy.create_labels(inputs["targets"], targets_segmentation=inputs.get("targets_segmentation", None))
     return self.strategy.compute_eval_loss(student_output, labels)
 
   def _prepare_inputs(
@@ -454,14 +454,6 @@ def train_distill(
   teacher_model.eval()
 
   # 3. Define Distillation Strategy
-  def labels_fn(targets, targets_segmentation=None, **kwargs):
-    """Converts integer targets to masked one-hot vectors for hard label loss."""
-    del kwargs  # Unused
-    one_hot = jax.nn.one_hot(targets, student_config.vocab_size)
-    mask = jnp.not_equal(targets, pad_id).astype(one_hot.dtype)[..., None]
-    if targets_segmentation is not None:
-      mask = mask * (targets_segmentation != 0)[..., None]
-    return one_hot * mask
 
   # Both Student and Teacher use the same forward logic via the adapter
   student_forward_fn = create_forward_fn(student_config)
@@ -471,12 +463,11 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   strategy = distillation_utils.CombinedDistillationStrategy(
       student_forward_fn=student_forward_fn,
       teacher_forward_fn=teacher_forward_fn,
-      labels_fn=labels_fn,
       temperature=student_config.distill_temperature,
       alpha=student_config.distill_alpha,
       beta_feature=student_config.distill_beta,
       layer_indices=student_config.distill_layer_indices,
-      sft_mode=student_config.use_sft,
+      vocab_size=student_config.vocab_size,
   )
 
   # 4. Optimizer & Config
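
Aside: the hunks shown here do not include how CombinedDistillationStrategy weights its terms, so the helper below is hypothetical, not this repository's code. It only illustrates the conventional knowledge-distillation objective that parameters like distill_alpha, distill_beta, and distill_temperature typically control:

import jax.numpy as jnp

def combined_loss(soft_loss, hard_loss, feature_loss, alpha=0.5, beta=0.0, temperature=2.0):
  # The T^2 factor keeps soft-target gradients on the same scale as the hard CE term
  return alpha * (temperature ** 2) * soft_loss + (1.0 - alpha) * hard_loss + beta * feature_loss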
@@ -539,6 +530,7 @@ def labels_fn(targets, targets_segmentation=None, **kwargs):
   raw_train_iter = _setup_and_restore_input_pipeline(trainer, raw_train_iter, student_config, train_config)
 
   # 8. Configure Input Mapping
-  def custom_gen_model_input_fn(batch):
-    inputs_dict = {
-        "input_tokens": batch.input_tokens,
@@ -568,6 +560,20 @@ def custom_gen_model_input_fn(batch):
-    return inputs_dict
-
-  trainer = trainer.with_gen_model_input_fn(custom_gen_model_input_fn)
+  trainer = trainer.with_gen_model_input_fn(
+      lambda batch: {
+          "input_tokens": batch.input_tokens,
+          "positions": batch.positions,
+          "attention_mask": batch.input_mask,
+          "decoder_segment_ids": batch.decoder_segment_ids,
+          "targets": batch.targets,  # Passed to strategy (create_labels)
+          "targets_position": batch.targets_position,  # Passed to strategy (create_labels)
+          "targets_segmentation": batch.targets_segmentation,  # Passed to strategy (create_labels)
+          "cache": None,
+      }
+  )
 
   # 9. Create Iterator Wrappers (Use Utils)
   train_iter = distillation_utils.MaxTextToTunixIterator(raw_train_iter)
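
Aside: the point of the new DistillationStrategy base class is that an alternative scheme only has to implement the three abstract methods. A hypothetical minimal subclass (the class name and KL-only objective are illustrative, not part of this commit; the import path assumes the package layout shown above):

import jax
import jax.numpy as jnp
import optax

from maxtext.trainers.post_train.distillation import distillation_utils


class KLOnlyStrategy(distillation_utils.DistillationStrategy):
  """Toy strategy that trains on the soft (KL) term alone."""

  def compute_loss(self, student_output, teacher_output, labels):
    s = jax.nn.log_softmax(student_output.logits.astype(jnp.float32), axis=-1)
    t = jax.nn.softmax(teacher_output.logits.astype(jnp.float32), axis=-1)
    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)  # mask padded positions
    kl = optax.kl_divergence(s, t, where=labels_mask)
    loss = jnp.mean(kl, where=jnp.squeeze(labels_mask, axis=-1))
    return loss, {"distill/kl_div": loss}

  def compute_eval_loss(self, student_output, labels):
    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
    ce = optax.softmax_cross_entropy(
        logits=student_output.logits.astype(jnp.float32), labels=labels, where=labels_mask
    )
    return jnp.mean(ce, where=jnp.squeeze(labels_mask, axis=-1)), {}

  def create_labels(self, targets, targets_segmentation=None, **kwargs):
    # Same masked one-hot construction as CombinedDistillationStrategy above
    one_hot = jax.nn.one_hot(targets, self.vocab_size)
    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
    if targets_segmentation is not None:
      mask = mask * (targets_segmentation != 0)[..., None]
    return one_hot * mask

Because the trainer only talks to the strategy through compute_loss, compute_eval_loss, and create_labels, such a class could be passed to MaxTextDistillationTrainer unchanged.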
