 model structures with Tunix's training interfaces.
 """

 import pickle
 import tensorflow as tf
 from array_record.python import array_record_module

+import abc
 from typing import Any, Iterator, Optional, List, Callable

 import flax
@@ -182,21 +186,87 @@ def __next__(self) -> MaxTextTrainingInput:
 # -----------------------------------------------------------------------------
 # Distillation Strategy
 # -----------------------------------------------------------------------------
-class CombinedDistillationStrategy:
+
+
+class DistillationStrategy(abc.ABC):
+  """Abstract base class for MaxText distillation strategies."""
+
+  def __init__(
+      self, student_forward_fn: Callable, teacher_forward_fn: Callable, vocab_size: int, pad_id: int = 0, **kwargs
+  ):
+    """Initializes the generic distillation strategy.
+
+    Args:
+      student_forward_fn: Function to compute student model outputs.
+      teacher_forward_fn: Function to compute teacher model outputs.
+      vocab_size: The size of the model's vocabulary.
+      pad_id: The ID used for padding tokens.
+    """
+    self.student_forward_fn = student_forward_fn
+    self.teacher_forward_fn = teacher_forward_fn
+    self.vocab_size = vocab_size
+    self.pad_id = pad_id
+
+  @abc.abstractmethod
+  def compute_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      teacher_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the distillation loss.
+
+    Args:
+      student_output: The forward pass output of the student model.
+      teacher_output: The forward pass output of the frozen teacher model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and a dictionary of auxiliary metrics
+      (e.g., {"distill/soft_loss": ..., "distill/total_loss": ...}).
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def compute_eval_loss(
+      self,
+      student_output: "DistillationForwardOutput",
+      labels: jax.Array,
+  ) -> tuple[jax.Array, dict[str, jax.Array]]:
+    """Computes the evaluation loss (typically just the task loss).
+
+    Args:
+      student_output: The forward pass output of the student model.
+      labels: The masked one-hot encoded ground truth labels.
+
+    Returns:
+      A tuple containing the scalar loss and an empty (or auxiliary) dict.
+    """
+    raise NotImplementedError
+
+  @abc.abstractmethod
+  def create_labels(self, targets: jax.Array, targets_segmentation: Optional[jax.Array] = None, **kwargs) -> jax.Array:
249+ """
250+ Creates labels tensor to compute the loss
251+ """
252+ raise NotImplementedError
253+
254+
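# A minimal sketch of how a training step is expected to wire this interface
# together. The step function, batch field names, and forward-fn call
# signatures below are illustrative assumptions, not part of this module.
def distillation_step(strategy: DistillationStrategy, batch: dict):
  # Build masked one-hot labels once, then share them across both loss terms.
  labels = strategy.create_labels(
      batch["targets"], targets_segmentation=batch.get("targets_segmentation")
  )
  student_out = strategy.student_forward_fn(batch)
  teacher_out = strategy.teacher_forward_fn(batch)  # teacher stays frozen
  loss, metrics = strategy.compute_loss(student_out, teacher_out, labels)
  return loss, metrics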
+class CombinedDistillationStrategy(DistillationStrategy):
   """Strategy that returns detailed metrics for TensorBoard."""

   def __init__(
       self,
       student_forward_fn: Callable[..., DistillationForwardOutput],
       teacher_forward_fn: Callable[..., DistillationForwardOutput],
-      labels_fn: Callable[..., jax.Array],
+      pad_id: int = 0,
       temperature: float = 2.0,
       alpha: float = 0.5,
       beta_feature: float = 0.0,
       layer_indices: Optional[List[int]] = None,
       feature_loss_fn: Callable[[jax.Array, jax.Array], jax.Array] | None = None,
       cosine_distance_axis: int | tuple[int, ...] = -1,
-      sft_mode: bool = False,
+      vocab_size: int = 0,
   ):
     """Initializes the Combined strategy using tunix logit.LogitStrategy.

@@ -213,9 +283,14 @@ def __init__(
       cosine_distance_axis: The axis to use for cosine distance computation if
         feature_loss_fn is not provided. Defaults to -1.
     """
-    self.student_forward_fn = student_forward_fn
-    self.teacher_forward_fn = teacher_forward_fn
-    self.labels_fn = labels_fn
+
+    super().__init__(
+        student_forward_fn=student_forward_fn,
+        teacher_forward_fn=teacher_forward_fn,
+        vocab_size=vocab_size,
+        pad_id=pad_id,
+    )
+
     self.temperature = temperature
     self.alpha = alpha
     self.beta_feature = beta_feature
@@ -226,7 +301,6 @@ def __init__(
       self.feature_loss_fn = lambda student_features, teacher_features: jnp.mean(
           optax.cosine_distance(student_features, teacher_features, axis=cosine_distance_axis)
       )
-    self.sft_mode = sft_mode

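# The default feature loss above is the mean cosine distance along the feature
# axis. A self-contained toy check; the shapes and values are illustrative
# assumptions, not taken from this module:
import jax.numpy as jnp
import optax

_s_feat = jnp.ones((2, 4, 8))        # [batch, seq, hidden] student features
_t_feat = 2.0 * jnp.ones((2, 4, 8))  # parallel to the student features
_dist = jnp.mean(optax.cosine_distance(_s_feat, _t_feat, axis=-1))
# Parallel vectors have cosine similarity 1, so _dist is ~0.0.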
   def compute_loss(
       self,
@@ -253,10 +327,9 @@ def compute_loss(

     log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
     teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
-
     # Labels are expected to have all SFT masks applied by this point.
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True) if self.sft_mode else None
-    mean_mask = jnp.squeeze(labels_mask, axis=-1) if labels_mask is not None else None
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)

     # KL(Teacher || Student)
     kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
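# optax.kl_divergence(log_p, q) computes KL(q || p), so the call above measures
# KL(teacher || student) at the distillation temperature. A self-contained toy
# check; the logits, temperature, and the T**2 scaling (the customary factor
# from Hinton et al., 2015) are illustrative assumptions, not values or code
# taken from this module:
import jax
import jax.numpy as jnp
import optax

_T = 2.0
_s_logits = jnp.array([[2.0, 0.5, -1.0]])  # [batch=1, vocab=3] student logits
_t_logits = jnp.array([[1.5, 1.0, -2.0]])  # teacher logits
_log_p_s = jax.nn.log_softmax(_s_logits / _T, axis=-1)
_p_t = jax.nn.softmax(_t_logits / _T, axis=-1)
_kl = optax.kl_divergence(_log_p_s, _p_t)  # one value per sequence position
_soft = (_T**2) * jnp.mean(_kl)            # T**2 keeps gradient scale comparable across T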
@@ -297,7 +370,7 @@ def compute_loss(
     metrics = {
         "distill/soft_loss": soft_loss,
         "distill/hard_loss": hard_loss,
-        "distill/kl_div": jnp.mean(kl_div),
+        "distill/kl_div": jnp.mean(kl_div, where=mean_mask),
         "distill/teacher_loss": teacher_hard_loss,
         "distill/out_proj_feature_loss": feature_loss,
         "distill/total_loss": total_loss,
@@ -316,12 +389,24 @@ def compute_eval_loss(
     # Parent logic for task loss.
     # We re-implement plain cross-entropy here to ensure float32 casting.
     s_logits = student_output.logits.astype(jnp.float32)
-    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
-    task_loss = jnp.mean(ce_loss)
+
+    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
+    mean_mask = jnp.squeeze(labels_mask, axis=-1)
+    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
+    task_loss = jnp.mean(ce_loss, where=mean_mask)

     # Must return a tuple because _has_aux=True expects it.
     return task_loss, {}

+  def create_labels(self, targets, targets_segmentation=None, **kwargs):
+    """Converts integer targets to masked one-hot vectors for the hard label loss."""
+    del kwargs  # Unused.
+    one_hot = jax.nn.one_hot(targets, self.vocab_size)
+    mask = jnp.not_equal(targets, self.pad_id).astype(one_hot.dtype)[..., None]
+    if targets_segmentation is not None:
+      mask = mask * (targets_segmentation != 0)[..., None]
+    return one_hot * mask
+

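# A quick illustration of create_labels: positions equal to pad_id become
# all-zero rows, which is exactly what the `labels != 0` masks in compute_loss
# and compute_eval_loss key off. The constructor arguments are illustrative
# placeholders, not a recommended configuration:
import jax.numpy as jnp

_strategy = CombinedDistillationStrategy(
    student_forward_fn=lambda batch: None,  # placeholder forward fns
    teacher_forward_fn=lambda batch: None,
    vocab_size=4,
)
_targets = jnp.array([[2, 3, 0]])  # last position is padding (pad_id defaults to 0)
_labels = _strategy.create_labels(_targets)
# _labels[0, 2] is [0., 0., 0., 0.] -> excluded from both the CE and KL terms.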
326411# -----------------------------------------------------------------------------
327412# Checkpoint Manager