-# Copyright 2023-2026 Google LLC
+# Copyright 2023–2026 Google LLC
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-#     http://www.apache.org/licenses/LICENSE-2.0
+#     https://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 import tensorflow as tf
 from array_record.python import array_record_module

+import abc
 from typing import Any, Iterator, Optional, List, Callable

 import flax
@@ -182,21 +183,87 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class CombinedDistillationStrategy:
+
+
+class DistillationStrategy(abc.ABC):
+  """Abstract base class for MaxText distillation strategies."""
+
+  def __init__(
+      self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
+  ):
+    """Initializes the generic distillation strategy.
+
+    Args:
+      student_forward_fn: Function to compute student model outputs.
+      teacher_forward_fn: Function to compute teacher model outputs.
+      vocab_size: The size of the model's vocabulary.
+      pad_id: The ID used for padding tokens.
+    """
+    self.student_forward_fn = student_forward_fn
+    self.teacher_forward_fn = teacher_forward_fn
+    self.vocab_size = vocab_size
+    self.pad_id = pad_id
+
+  @abc.abstractmethod
+  def compute_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      teacher_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the distillation loss.
+
+    Args:
+      student_output: The forward pass output of the student model.
+      teacher_output: The forward pass output of the frozen teacher model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and a dictionary of auxiliary metrics
+      (e.g., {"distill/soft_loss": ..., "distill/total_loss": ...}).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def compute_eval_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the evaluation loss (typically just the task loss).
+
+    Args:
+      student_output: The forward pass output of the student model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and an empty (or auxiliary) dict.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
+    """Creates the labels tensor used to compute the loss."""
+    raise NotImplementedError
+
+
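The base class pins down a three-method contract: compute_loss (student against a frozen teacher), compute_eval_loss (student only), and create_labels (masked one-hot targets). As a quick illustration, here is a minimal conforming subclass that trains on hard labels alone. HardLabelStrategy is hypothetical and not part of this commit; it assumes the DistillationForwardOutput.logits field defined elsewhere in this file, and an optax version whose losses accept a where mask (the same assumption the diff itself makes).

import jax
import jax.numpy as jnp
import optax


class HardLabelStrategy(DistillationStrategy):
  """Sketch: ignores the teacher and trains on hard labels only."""

  def compute_loss(self, student_output, teacher_output, labels):
    del teacher_output  # No soft-label term in this sketch.
    loss, _ = self.compute_eval_loss(student_output, labels)
    return loss, {"distill/total_loss": loss}

  def compute_eval_loss(self, student_output, labels):
    s_logits = student_output.logits.astype(jnp.float32)
    # All-zero one-hot rows mark padding; exclude them from the mean.
    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
    ce = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
    return jnp.mean(ce, where=jnp.squeeze(labels_mask, axis=-1)), {}

  def create_labels(self, targets, targets_segmentation=None, **kwargs):
    del kwargs
    one_hot = jax.nn.one_hot(targets, self.vocab_size)
    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
    if targets_segmentation is not None:
      mask = mask * (targets_segmentation != 0)[..., None]
    return one_hot * mask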
+class CombinedDistillationStrategy(DistillationStrategy):
   """Strategy that returns detailed metrics for TensorBoard."""

   def __init__(
       self,
       student_forward_fn: Callable[..., DistillationForwardOutput],
       teacher_forward_fn: Callable[..., DistillationForwardOutput],
-      labels_fn: Callable[..., jax.Array],
+      pad_id: int = 0,
       temperature: float = 2.0,
       alpha: float = 0.5,
       beta_feature: float = 0.0,
       layer_indices: Optional[List[int]] = None,
       feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
       cosine_distance_axis: int | tuple[int, ...] = -1,
-      sft_mode: bool = False,
+      vocab_size: int = 0,
   ):
     """Initializes the Combined strategy using tunix logit.LogitStrategy.
@@ -213,9 +280,14 @@ def __init__(
       cosine_distance_axis: The axis to use for cosine distance computation if
         feature_loss_fn is not provided. Defaults to -1.
     """
-    self.student_forward_fn = student_forward_fn
-    self.teacher_forward_fn = teacher_forward_fn
-    self.labels_fn = labels_fn
+
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        vocab_size=vocab_size,
+        pad_id=pad_id,
+    )
+
     self.temperature = temperature
     self.alpha = alpha
     self.beta_feature = beta_feature
@@ -226,7 +298,6 @@ def __init__(
       self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
           optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
       )
-    self.sft_mode = sft_mode

   def compute_loss(
       self,
@@ -253,10 +324,9 @@ def compute_loss(

     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
-
     # labels are supposed to have all sft masks applied by this moment
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
-    mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)

     # KL(Teacher || Student)
     kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
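For intuition, the temperature-scaled KL term above can be exercised in isolation. A small numeric sketch, assuming an optax version where kl_divergence accepts a where mask, as the line above already does; shapes and values here are made up:

import jax
import jax.numpy as jnp
import optax

T = 2.0  # temperature
s_logits = jnp.array([[[2.0, 0.0, -1.0], [0.5, 0.5, 0.5]]])  # [batch=1, seq=2, vocab=3]
t_logits = jnp.array([[[1.5, 0.5, -0.5], [1.0, 0.0, 0.0]]])

log_p_student = jax.nn.log_softmax(s_logits / T, axis=-1)
p_teacher = jax.nn.softmax(t_logits / T, axis=-1)

# Pretend position 1 is padding: its one-hot label row is all zeros.
labels = jnp.array([[[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)  # [1, 2, 1]
mean_mask = jnp.squeeze(labels_mask, axis=-1)               # [1, 2]

kl = optax.kl_divergence(log_p_student, p_teacher, where=labels_mask)
print(jnp.mean(kl, where=mean_mask))  # KL averaged over non-padded positions only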
@@ -297,7 +367,7 @@ def compute_loss(
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
-        "distill/kl_div": jnp.mean(kl_div),
+        "distill/kl_div": jnp.mean(kl_div, where=mean_mask),
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
@@ -316,12 +386,24 @@ def compute_eval_loss(
     # Parent logic for task loss
     # We re-implement a simple cross-entropy here to ensure float32 casting
     s_logits = student_output.logits.astype(jnp.float32)
-    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
-    task_loss = jnp.mean(ce_loss)
+
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)
+    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
+    task_loss = jnp.mean(ce_loss, where=mean_mask)

     # Must return a tuple because _has_aux=True expects it
     return task_loss, {}

+  def create_labels(self, targets, targets_segmentation=None, **kwargs):
+    """Converts integer targets to masked one-hot vectors for hard label loss."""
+    del kwargs  # Unused
+    one_hot = jax.nn.one_hot(targets, self.vocab_size)
+    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
+    if targets_segmentation is not None:
+      mask = mask * (targets_segmentation != 0)[..., None]
+    return one_hot * mask
+

 # -----------------------------------------------------------------------------
 # Checkpoint Manager
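For reference, a typical call sequence for the refactored strategy would look like the sketch below. The lambda forward functions and the vocabulary size are placeholders rather than real MaxText wiring:

import jax.numpy as jnp

strategy = CombinedDistillationStrategy(
    student_forward_fn=lambda *a, **kw: ...,  # placeholder for the student apply fn
    teacher_forward_fn=lambda *a, **kw: ...,  # placeholder for the frozen teacher
    pad_id=0,
    temperature=2.0,
    alpha=0.5,
    vocab_size=32_000,                        # placeholder vocab size
)

targets = jnp.array([[5, 9, 0, 0]])    # 0 is the pad id
segments = jnp.array([[1, 1, 0, 0]])   # 0 marks padded positions
labels = strategy.create_labels(targets, targets_segmentation=segments)

assert labels.shape == (1, 4, 32_000)
# Padded positions become all-zero rows, which is exactly what
# compute_loss and compute_eval_loss key on via `labels != 0`.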