@@ -185,6 +185,38 @@ def __next__(self) -> MaxTextTrainingInput:
185185# -----------------------------------------------------------------------------
186186
187187
188+ def compute_schedule (
189+ step : jax .Array ,
190+ max_steps : int ,
191+ start_value : float ,
192+ end_value : float | None ,
193+ schedule_type : str ,
194+ ) -> jax .Array :
195+ """Computes a scheduled value based on training progress.
196+
197+ Args:
198+ step: Current training step as a JAX array.
199+ max_steps: Total number of training steps.
200+ start_value: Value at the beginning of training.
201+ end_value: Value at the end of training. If None, returns start_value.
202+ schedule_type: One of "constant", "linear", or "cosine".
203+
204+ Returns:
205+ The scheduled value as a JAX scalar.
206+ """
207+ if end_value is None or schedule_type == "constant" :
208+ return jnp .array (start_value , dtype = jnp .float32 )
209+
210+ progress = jnp .clip (step .astype (jnp .float32 ) / max_steps , 0.0 , 1.0 )
211+
212+ if schedule_type == "linear" :
213+ return start_value + (end_value - start_value ) * progress
214+ elif schedule_type == "cosine" :
215+ return end_value + (start_value - end_value ) * 0.5 * (1.0 + jnp .cos (jnp .pi * progress ))
216+ else :
217+ raise ValueError (f"Unsupported schedule_type: { schedule_type !r} . Must be 'constant', 'linear', or 'cosine'." )
218+
219+
188220class DistillationStrategy (abc .ABC ):
189221 """Abstract base class for MaxText Distillation Strategies."""
190222
@@ -210,13 +242,15 @@ def compute_loss(
210242 student_output : "DistillationForwardOutput" ,
211243 teacher_output : "DistillationForwardOutput" ,
212244 labels : jax .Array ,
245+ step : jax .Array | None = None ,
213246 ) -> tuple [jax .Array , dict [str , jax .Array ]]:
214247 """Computes the distillation loss.
215248
216249 Args:
217250 student_output: The forward pass output of the student model.
218251 teacher_output: The forward pass output of the frozen teacher model.
219252 labels: The masked one-hot encoded ground truth labels.
253+ step: Current training step for dynamic scheduling. If None, uses fixed values.
220254
221255 Returns:
222256 A tuple containing the scalar loss and a dictionary of auxiliary metrics
@@ -265,8 +299,15 @@ def __init__(
265299 feature_loss_type : Literal ["cosine" , "l2" ] = "cosine" ,
266300 cosine_distance_axis : int | tuple [int , ...] = - 1 ,
267301 vocab_size : int = 0 ,
302+ alpha_end : float | None = None ,
303+ alpha_schedule : Literal ["constant" , "linear" , "cosine" ] = "constant" ,
304+ temperature_end : float | None = None ,
305+ temperature_schedule : Literal ["constant" , "linear" , "cosine" ] = "constant" ,
306+ beta_end : float | None = None ,
307+ beta_schedule : Literal ["constant" , "linear" , "cosine" ] = "constant" ,
308+ max_steps : int = 1 ,
268309 ):
269- """Initializes the Combined strategy using tunix logit.LogitStrategy .
310+ """Initializes the Combined distillation strategy .
270311
271312 Args:
272313 student_forward_fn: Function to compute student model outputs.
@@ -282,6 +323,13 @@ def __init__(
282323 teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
283324 cosine_distance_axis: The axis to use for cosine distance computation if
284325 feature_loss_fn is not provided. Defaults to -1.
326+ alpha_end: Target alpha value at end of training. None keeps alpha fixed.
327+ alpha_schedule: Schedule type for alpha annealing.
328+ temperature_end: Target temperature at end of training. None keeps temperature fixed.
329+ temperature_schedule: Schedule type for temperature annealing.
330+ beta_end: Target beta_feature value at end of training. None keeps beta fixed.
331+ beta_schedule: Schedule type for beta annealing.
332+ max_steps: Total training steps, used for schedule computation.
285333 """
286334
287335 super ().__init__ (
@@ -296,6 +344,39 @@ def __init__(
296344 self .beta_feature = beta_feature
297345 self .layer_indices = jnp .array (layer_indices ) if layer_indices is not None else None
298346
347+ # Schedule parameters
348+ self .alpha_end = alpha_end
349+ self .alpha_schedule = alpha_schedule
350+ self .temperature_end = temperature_end
351+ self .temperature_schedule = temperature_schedule
352+ self .beta_end = beta_end
353+ self .beta_schedule = beta_schedule
354+ self .max_steps = max_steps
355+
356+ # Validate schedule parameter ranges
357+ if alpha_end is not None and not 0.0 <= alpha_end <= 1.0 :
358+ raise ValueError (f"alpha_end must be in [0, 1], got { alpha_end } " )
359+ if temperature_end is not None and temperature_end <= 0.0 :
360+ raise ValueError (f"temperature_end must be > 0, got { temperature_end } " )
361+ if beta_end is not None and beta_end < 0.0 :
362+ raise ValueError (f"beta_end must be >= 0, got { beta_end } " )
363+ if beta_feature == 0.0 and beta_end is not None and beta_end > 0.0 :
364+ raise ValueError (
365+ f"distill_beta=0.0 but distill_beta_end={ beta_end } . Feature extraction is disabled when "
366+ "distill_beta starts at 0.0 (the model does not sow intermediate activations). "
367+ "Set distill_beta to a small positive value (e.g., 1e-6) to enable feature extraction."
368+ )
369+ for param_name , schedule , end_value in [
370+ ("alpha" , alpha_schedule , alpha_end ),
371+ ("temperature" , temperature_schedule , temperature_end ),
372+ ("beta" , beta_schedule , beta_end ),
373+ ]:
374+ if schedule != "constant" and end_value is None :
375+ raise ValueError (
376+ f"{ param_name } _schedule is '{ schedule } ' but { param_name } _end is None. "
377+ f"Set { param_name } _end to a target value or use schedule='constant'."
378+ )
379+
299380 self .feature_loss_fn = feature_loss_fn
300381 if feature_loss_fn is None :
301382 if feature_loss_type == "cosine" :
@@ -309,13 +390,39 @@ def __init__(
309390 else :
310391 raise ValueError (f"Unsupported feature_loss_type: { feature_loss_type !r} " )
311392
393+ def _get_scheduled_weights (self , step : jax .Array | None ) -> tuple [jax .Array , jax .Array , jax .Array ]:
394+ """Resolves the current alpha, temperature, and beta values from schedules.
395+
396+ Args:
397+ step: Current training step. If None, returns the fixed initial values.
398+
399+ Returns:
400+ A tuple of (alpha, temperature, beta_feature) as JAX scalars.
401+ """
402+ if step is None :
403+ return (
404+ jnp .array (self .alpha , dtype = jnp .float32 ),
405+ jnp .array (self .temperature , dtype = jnp .float32 ),
406+ jnp .array (self .beta_feature , dtype = jnp .float32 ),
407+ )
408+ alpha = compute_schedule (step , self .max_steps , self .alpha , self .alpha_end , self .alpha_schedule )
409+ temperature = compute_schedule (
410+ step , self .max_steps , self .temperature , self .temperature_end , self .temperature_schedule
411+ )
412+ beta_feature = compute_schedule (step , self .max_steps , self .beta_feature , self .beta_end , self .beta_schedule )
413+ return alpha , temperature , beta_feature
414+
312415 def compute_loss (
313416 self ,
314417 student_output : DistillationForwardOutput ,
315418 teacher_output : DistillationForwardOutput ,
316419 labels : jax .Array ,
420+ step : jax .Array | None = None ,
317421 ) -> tuple [jax .Array , dict [str , jax .Array ]]:
318422 """Computes Loss and Auxiliary Metrics."""
423+ # Resolve scheduled weights for this step
424+ alpha , temperature , beta_feature = self ._get_scheduled_weights (step )
425+
319426 # Calculate Distillation Loss (KL Divergence)
320427 # Scale logits by temperature T for soft targets
321428 # We use explicit float32 casting for stability in loss calculation
@@ -332,8 +439,8 @@ def compute_loss(
332439 "Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
333440 )
334441
335- log_student_probs_temp = jax .nn .log_softmax (s_logits / self . temperature , axis = - 1 )
336- teacher_probs_temp = jax .nn .softmax (t_logits / self . temperature , axis = - 1 )
442+ log_student_probs_temp = jax .nn .log_softmax (s_logits / temperature , axis = - 1 )
443+ teacher_probs_temp = jax .nn .softmax (t_logits / temperature , axis = - 1 )
337444 # labels are supposed to have all sft masks applied by this moment
338445 labels_mask = jnp .any (labels != 0 , axis = - 1 , keepdims = True )
339446 mean_mask = jnp .squeeze (labels_mask , axis = - 1 )
@@ -342,7 +449,7 @@ def compute_loss(
342449 kl_div = optax .kl_divergence (log_student_probs_temp , teacher_probs_temp , where = labels_mask )
343450
344451 # Scale gradients by T^2 (Hinton et al.)
345- soft_loss = jnp .mean (kl_div , where = mean_mask ) * (self . temperature ** 2 )
452+ soft_loss = jnp .mean (kl_div , where = mean_mask ) * (temperature ** 2 )
346453
347454 # 1. Student Hard Loss (Existing)
348455 ce_loss_student = optax .softmax_cross_entropy (logits = s_logits , labels = labels , where = labels_mask )
@@ -353,7 +460,7 @@ def compute_loss(
353460 teacher_hard_loss = jnp .mean (ce_loss_teacher , where = mean_mask )
354461
355462 # 3. Combine losses
356- base_logit_loss = (self . alpha * soft_loss ) + ((1.0 - self . alpha ) * hard_loss )
463+ base_logit_loss = (alpha * soft_loss ) + ((1.0 - alpha ) * hard_loss )
357464
358465 feature_loss = jnp .array (0.0 )
359466 if self .beta_feature > 0.0 :
@@ -369,21 +476,21 @@ def compute_loss(
369476 s_features_sliced = s_features_sliced .astype (jnp .float32 )
370477 t_features_sliced = t_features_sliced .astype (jnp .float32 )
371478
372- feature_loss = self . beta_feature * self .feature_loss_fn (s_features_sliced , t_features_sliced )
479+ feature_loss = beta_feature * self .feature_loss_fn (s_features_sliced , t_features_sliced )
373480
374481 total_loss = base_logit_loss + feature_loss
375482
376- # 4. Return Loss AND Metrics
483+ # 4. Return Loss AND Metrics (log dynamic values for TensorBoard verification)
377484 metrics = {
378485 "distill/soft_loss" : soft_loss ,
379486 "distill/hard_loss" : hard_loss ,
380487 "distill/kl_div" : jnp .mean (kl_div , where = mean_mask ),
381488 "distill/teacher_loss" : teacher_hard_loss ,
382489 "distill/out_proj_feature_loss" : feature_loss ,
383490 "distill/total_loss" : total_loss ,
384- "distill/temperature" : self . temperature ,
385- "distill/alpha" : self . alpha ,
386- "distill/beta_feature" : self . beta_feature ,
491+ "distill/temperature" : temperature ,
492+ "distill/alpha" : alpha ,
493+ "distill/beta_feature" : beta_feature ,
387494 }
388495 return total_loss , metrics
389496