 model structures with Tunix's training interfaces.
 """
 
-import pickle
-import tensorflow as tf
-from array_record.python import array_record_module
-
 import abc
-from typing import Any, Iterator, Optional, List, Callable, Literal
+import pickle
+from typing import Any, Callable, Iterator, List, Literal, Optional, Sequence
 
 import flax
 from flax import nnx
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+import tensorflow as tf
+from array_record.python import array_record_module
 from orbax import checkpoint
 
 from maxtext.utils import max_logging
 # Reuse MaxText's native checkpointing logic
 from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
-from tunix.sft import peft_trainer
 from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
+from tunix.sft import peft_trainer
 
 
 # -----------------------------------------------------------------------------
@@ -184,6 +184,11 @@ def __next__(self) -> MaxTextTrainingInput:
 # Distillation Strategy
 # -----------------------------------------------------------------------------
 
+# Clamp CE before exp() so a divergence spike doesn't poison PPL averages
+# with inf. 20 nats is well above plausible CE (Llama random-init ~11.76)
+# and far below fp32 exp overflow (~88).
+_PPL_CE_CAP = 20.0
+
 
 def compute_schedule(
     step: jax.Array,
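A quick standalone sanity check of the cap chosen above (the 20-nat threshold is a judgment call, not derived anywhere in this diff; the arithmetic below is plain numpy):

import numpy as np

print(np.exp(np.float32(20.0)))  # ~4.85e+08: a finite, plottable perplexity ceiling
print(np.exp(np.float32(89.0)))  # inf: float32 exp overflows just past ~88.7 nats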
@@ -217,6 +222,33 @@ def compute_schedule(
   raise ValueError(f"Unsupported schedule_type: {schedule_type!r}. Must be 'constant', 'linear', or 'cosine'.")
 
 
+def weighted_mean(sum_count_pairs: Sequence[tuple[Any, Any]] | np.ndarray) -> float:
+  """Aggregates `(sum, count)` pairs into a single token-weighted mean.
+
+  Used as the aggregation function for metrics emitted by `compute_loss` and
+  `compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
+  counts across logging steps:
+    final_value = sum(sums) / sum(counts)
+
+  Accepts either a list of (sum, count) tuples or an ndarray of shape (N, 2).
+  Tunix's metrics writer can pass either form, so we normalize here.
+
+  Returns 0.0 for an empty input or when total count is non-positive.
+  """
+  arr = np.asarray(sum_count_pairs, dtype=np.float32)
+  if arr.size == 0:
+    return 0.0
+  # Normalize shape. Single pair -> (1, 2); list of pairs -> (N, 2).
+  if arr.ndim == 1:
+    arr = arr.reshape(1, -1)
+  if arr.ndim != 2 or arr.shape[1] != 2:
+    return 0.0
+  total = float(arr[:, 1].sum())
+  if total <= 0.0:
+    return 0.0
+  return float(arr[:, 0].sum() / total)
+
+
 class DistillationStrategy(abc.ABC):
   """Abstract base class for MaxText Distillation Strategies."""
 
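For reviewers, the aggregation contract in one worked example (all numbers invented). Note that a naive mean of per-host means would overweight the smaller host:

# Host A reports (sum=120, count=50) -> mean 2.4; host B reports (sum=12, count=10) -> mean 1.2.
weighted_mean([(120.0, 50.0), (12.0, 10.0)])            # (120 + 12) / (50 + 10) = 2.2
weighted_mean(np.array([[120.0, 50.0], [12.0, 10.0]]))  # same result via the ndarray form
weighted_mean((120.0, 50.0))                            # a bare pair (1-D) is reshaped to (1, 2) -> 2.4
weighted_mean([])                                       # 0.0 (empty-input guard)
# A naive mean of per-host means, (2.4 + 1.2) / 2 = 1.8, would overweight host B.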
@@ -312,15 +344,14 @@ def __init__(
     Args:
       student_forward_fn: Function to compute student model outputs.
       teacher_forward_fn: Function to compute teacher model outputs.
-      labels_fn: Function to compute labels from model inputs.
       temperature: Temperature for softening probabilities (> 0).
       alpha: Weight to balance distillation loss and task loss (0.0 to 1.0).
       beta_feature: Weight to balance feature loss (0.0 to 1.0). 0.0 disables feature loss.
       layer_indices: Layer indices to apply feature loss.
       feature_loss_type: The type of feature loss to use if `feature_loss_fn` is None.
         Can be "cosine" (default) or "l2".
-      feature_loss_fn: A function that takes two jax. Arrays (student_map,
-        teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
+      feature_loss_fn: A function that takes two jax.Arrays (student_map,
+        teacher_map) and returns a scalar loss. Defaults to cosine distance.
       cosine_distance_axis: The axis to use for cosine distance computation if
         feature_loss_fn is not provided. Defaults to -1.
       alpha_end: Target alpha value at end of training. None keeps alpha fixed.
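Since `feature_loss_fn` defaults to cosine distance along `cosine_distance_axis`, here is a minimal sketch of a compatible callable (illustrative only; the in-tree default may differ in reduction and epsilon details):

import jax
import jax.numpy as jnp

def cosine_distance_loss(student_map: jax.Array, teacher_map: jax.Array, axis: int = -1) -> jax.Array:
  # 1 - cosine similarity per position, reduced to a scalar mean.
  s_norm = student_map / (jnp.linalg.norm(student_map, axis=axis, keepdims=True) + 1e-8)
  t_norm = teacher_map / (jnp.linalg.norm(teacher_map, axis=axis, keepdims=True) + 1e-8)
  return jnp.mean(1.0 - jnp.sum(s_norm * t_norm, axis=axis))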
@@ -418,18 +449,19 @@ def compute_loss(
       teacher_output: DistillationForwardOutput,
       labels: jax.Array,
       step: jax.Array | None = None,
-  ) -> tuple[jax.Array, dict[str, jax.Array]]:
-    """Computes Loss and Auxiliary Metrics."""
+  ) -> tuple[jax.Array, dict[str, tuple[jax.Array, jax.Array]]]:
+    """Computes Loss and Auxiliary Metrics.
+
+    Metrics are emitted as (sum, count) pairs so that they can be aggregated
+    across hosts and across logging windows in a token-weighted (unbiased) way:
+      final_value = sum(sums) / sum(counts).
+    """
     # Resolve scheduled weights for this step
     alpha, temperature, beta_feature = self._get_scheduled_weights(step)
 
-    # Calculate Distillation Loss (KL Divergence)
-    # Scale logits by temperature T for soft targets
-    # We use explicit float32 casting for stability in loss calculation
     s_logits = student_output.logits.astype(jnp.float32)
     t_logits = teacher_output.logits.astype(jnp.float32)
 
-    # Shape: [num_layers, batch, seq, hidden_dim]
     s_features = student_output.out_projection_activations
     t_features = teacher_output.out_projection_activations
 
@@ -439,34 +471,42 @@ def compute_loss(
439471 "Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
440472 )
441473
442- log_student_probs_temp = jax .nn .log_softmax (s_logits / temperature , axis = - 1 )
443- teacher_probs_temp = jax .nn .softmax (t_logits / temperature , axis = - 1 )
444- # labels are supposed to have all sft masks applied by this moment
445- labels_mask = jnp .any (labels != 0 , axis = - 1 , keepdims = True )
446- mean_mask = jnp .squeeze (labels_mask , axis = - 1 )
447-
448- # KL(Teacher || Student)
449- kl_div = optax .kl_divergence (log_student_probs_temp , teacher_probs_temp , where = labels_mask )
450-
451- # Scale gradients by T^2 (Hinton et al.)
452- soft_loss = jnp .mean (kl_div , where = mean_mask ) * (temperature ** 2 )
453-
454- # 1. Student Hard Loss (Existing)
455- ce_loss_student = optax .softmax_cross_entropy (logits = s_logits , labels = labels , where = labels_mask )
456- hard_loss = jnp .mean (ce_loss_student , where = mean_mask )
457-
458- # 2. Teacher Hard Loss (For Verification)
459- ce_loss_teacher = optax .softmax_cross_entropy (logits = t_logits , labels = labels , where = labels_mask )
460- teacher_hard_loss = jnp .mean (ce_loss_teacher , where = mean_mask )
461-
462- # 3. Combine losses
463- base_logit_loss = (alpha * soft_loss ) + ((1.0 - alpha ) * hard_loss )
464-
465- feature_loss = jnp .array (0.0 )
+    # Per-token validity mask, derived from the one-hot labels so we don't need
+    # a separate mask input. A padded (fully-zero) row makes `jnp.any(labels != 0)` False.
+    mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32)  # [B, T]
+    valid_count = jnp.sum(mask)
+    safe_count = jnp.maximum(valid_count, 1.0)
+
+    # --- Soft loss: KL on temperature-softened distributions ---
+    log_s_T = jax.nn.log_softmax(s_logits / temperature, axis=-1)
+    t_p_T = jax.nn.softmax(t_logits / temperature, axis=-1)
+    # KL(teacher || student) per position. optax.kl_divergence(log_pred, target) = KL(target || pred).
+    kl_softened_per_pos = optax.kl_divergence(log_s_T, t_p_T)  # [B, T]
+    kl_softened_sum = jnp.sum(kl_softened_per_pos * mask)
+    # Scale by T^2 (Hinton). Apply once at the loss; logged metric is the scaled sum too.
+    soft_loss_sum_scaled = kl_softened_sum * (temperature**2)
+    soft_loss_mean = soft_loss_sum_scaled / safe_count
+
+    # --- Hard loss: student CE against ground-truth ---
+    ce_student_per_pos = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
+    ce_student_sum = jnp.sum(ce_student_per_pos * mask)
+    hard_loss_mean = ce_student_sum / safe_count
+
+    # --- Teacher CE (verification metric) ---
+    ce_teacher_per_pos = optax.softmax_cross_entropy(logits=t_logits, labels=labels)
+    ce_teacher_sum = jnp.sum(ce_teacher_per_pos * mask)
+
+    # --- Always-T=1 KL for cross-run / cross-anneal comparability ---
+    log_s_1 = jax.nn.log_softmax(s_logits, axis=-1)
+    t_p_1 = jax.nn.softmax(t_logits, axis=-1)
+    kl_t1_per_pos = optax.kl_divergence(log_s_1, t_p_1)
+    kl_t1_sum = jnp.sum(kl_t1_per_pos * mask)
+
+    base_logit_loss = (alpha * soft_loss_mean) + ((1.0 - alpha) * hard_loss_mean)
+
+    feature_loss = jnp.array(0.0, dtype=jnp.float32)
     if self.beta_feature > 0.0:
-
       if self.layer_indices is not None:
-        # jnp.take slices along axis=0 (the layer dimension)
         s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0)
         t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0)
       else:
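The `optax.kl_divergence` argument order is easy to get backwards, so here is a tiny standalone check of the convention the comment in this hunk relies on (distributions invented):

import jax.numpy as jnp
import optax

t = jnp.array([0.7, 0.2, 0.1])  # teacher probabilities
s = jnp.array([0.5, 0.3, 0.2])  # student probabilities
# optax.kl_divergence(log_predictions, targets) computes KL(targets || predictions),
# so passing (log s, t) yields KL(teacher || student), as the code above intends.
manual = jnp.sum(t * (jnp.log(t) - jnp.log(s)))
via_optax = optax.kl_divergence(jnp.log(s), t)
assert jnp.allclose(manual, via_optax)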
@@ -480,37 +520,59 @@ def compute_loss(
 
     total_loss = base_logit_loss + feature_loss
 
-    # 4. Return Loss AND Metrics (log dynamic values for TensorBoard verification)
-    metrics = {
-        "distill/soft_loss": soft_loss,
-        "distill/hard_loss": hard_loss,
-        "distill/kl_div": jnp.mean(kl_div, where=mean_mask),
-        "distill/teacher_loss": teacher_hard_loss,
-        "distill/out_proj_feature_loss": feature_loss,
-        "distill/total_loss": total_loss,
-        "distill/temperature": temperature,
-        "distill/alpha": alpha,
-        "distill/beta_feature": beta_feature,
+    # Per-step next-token perplexity. Note: this is mean(exp(per-step CE)), not
+    # exp(window-CE-mean); the two are close in steady state. For the exact
+    # perplexity over a logging window, compute exp(distill/hard_loss) on the TB side.
+    teacher_loss_mean = ce_teacher_sum / safe_count
+    student_perplexity_step = jnp.exp(jnp.minimum(hard_loss_mean, _PPL_CE_CAP))
+    teacher_perplexity_step = jnp.exp(jnp.minimum(teacher_loss_mean, _PPL_CE_CAP))
+
+    one = jnp.array(1.0, dtype=jnp.float32)
+    metrics: dict[str, tuple[jax.Array, jax.Array]] = {
+        # Token-weighted: emit (sum, valid_count) so multi-host averaging is unbiased.
+        "distill/soft_loss": (soft_loss_sum_scaled, valid_count),
+        "distill/hard_loss": (ce_student_sum, valid_count),
+        "distill/teacher_loss": (ce_teacher_sum, valid_count),
+        # Next-token prediction perplexity (per-step approximation of exp(hard_loss)).
+        # The headline `_train_perplexity` Tunix prints is exp(total_loss) which for
+        # distillation is exp(α·soft + (1-α)·hard + β·feature) and NOT next-token PPL.
+        "distill/student_perplexity": (student_perplexity_step, one),
+        "distill/teacher_perplexity": (teacher_perplexity_step, one),
+        # KL at the current (scheduled) temperature T, without the T^2 scaling
+        # that soft_loss applies. Pair with kl_div_T1 to compare T vs T=1.
+        "distill/kl_div_at_T": (kl_softened_sum, valid_count),
+        # KL at T=1: comparable across runs / annealing schedules.
+        "distill/kl_div_T1": (kl_t1_sum, valid_count),
+        # Per-step quantities: (value, 1.0) so the aggregator yields a simple mean over steps.
+        "distill/out_proj_feature_loss": (feature_loss, one),
+        "distill/total_loss": (total_loss, one),
+        "distill/temperature": (temperature, one),
+        "distill/alpha": (alpha, one),
+        "distill/beta_feature": (beta_feature, one),
     }
     return total_loss, metrics
 
   def compute_eval_loss(
       self,
       student_output: DistillationForwardOutput,
       labels: jax.Array,
-  ) -> tuple[jax.Array, dict[str, jax.Array]]:
-    """Computes Eval Loss and returns empty aux dict (required for consistency)."""
-    # Parent logic for task loss
-    # We re-implement simple CE here to ensure float32 casting
+  ) -> tuple[jax.Array, dict[str, tuple[jax.Array, jax.Array]]]:
+    """Computes Eval Loss. Returns (loss, metrics) with (sum, count) metric pairs."""
     s_logits = student_output.logits.astype(jnp.float32)
 
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
-    mean_mask = jnp.squeeze(labels_mask, axis=-1)
-    ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
-    task_loss = jnp.mean(ce_loss, where=mean_mask)
+    mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32)
+    valid_count = jnp.sum(mask)
+    safe_count = jnp.maximum(valid_count, 1.0)
+
+    ce_per_pos = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
+    ce_sum = jnp.sum(ce_per_pos * mask)
+    task_loss = ce_sum / safe_count
 
-    # Must return a tuple because _has_aux=True expects it
-    return task_loss, {}
+    metrics = {
+        "eval/hard_loss": (ce_sum, valid_count),
+        "eval/student_perplexity": (jnp.exp(jnp.minimum(task_loss, _PPL_CE_CAP)), jnp.array(1.0, dtype=jnp.float32)),
+    }
+    return task_loss, metrics
 
   def create_labels(self, targets, targets_segmentation=None, **kwargs):
     """Converts integer targets to masked one-hot vectors for hard label loss."""
@@ -574,11 +636,13 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
     if not force and not self._checkpoint_manager.should_save(step):
       return False
 
-    # Standard Tunix Logic for Model/Optimizer
+    # Standard Tunix Logic for Model/Optimizer.
+    # Accept either a ModelBundle (common path) or a plain nnx module.
+    target_model = getattr(model, "student_model", model)
     if save_only_lora_params:
-      params = nnx.state(model.student_model, nnx.LoRAParam)
+      params = nnx.state(target_model, nnx.LoRAParam)
     else:
-      params = nnx.state(model.student_model)
+      params = nnx.state(target_model)
 
     # Define standard SaveArgs once to reuse
     default_save_args = checkpoint.SaveArgs()
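The `getattr` unwrap above is the entire ModelBundle compatibility shim; a toy illustration (class name hypothetical):

class FakeBundle:
  def __init__(self, student_model):
    self.student_model = student_model

student = object()
bundle = FakeBundle(student)
assert getattr(bundle, "student_model", bundle) is student    # bundle -> student model
assert getattr(student, "student_model", student) is student  # plain module passes through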
@@ -625,71 +689,27 @@ def maybe_restore(
       optimizer: Any = None,
       restore_only_lora_params: bool = False,
   ) -> tuple[int, dict[str, Any]]:
-    """Restores model and optimizer state if a checkpoint exists, using correct sharding specs.
-
-    This method checks for the latest available checkpoint. If found, it restores the
-    model parameters and optionally the optimizer state in-place. It automatically
-    maps the parameter's `sharding` attributes to Orbax restore arguments to ensure
-    the tensors are placed on the correct device meshes.
+    """Restores model + optimizer by delegating to upstream Tunix.
 
-    Args:
-      model: The model to restore. If a `ModelBundle` is provided, it automatically
-        extracts and restores only the `student_model`.
-      optimizer: The optimizer state to restore. If None, optimizer restoration is skipped.
-      restore_only_lora_params: If True, restricts restoration to parameters marked
-        as `nnx.LoRAParam`.
+    Unwraps `ModelBundle` if present (we only restore `student_model`).
 
     Returns:
-      A tuple containing the restored step number (0 if no checkpoint was found)
-        and a dictionary of custom metadata.
+      (restored step, custom_metadata dict). Step is 0 if no checkpoint exists.
     """
     if self._checkpoint_manager is None:
       return 0, {}
 
-    step = self._checkpoint_manager.latest_step()
-    if step is None:
-      return 0, {}
-
-    max_logging.log(f"Restoring from checkpoint step {step}...")
-
-    # Extract student model safely
     target_model = getattr(model, "student_model", model)
 
-    if restore_only_lora_params:
-      params = nnx.state(target_model, nnx.LoRAParam)
-    else:
-      params = nnx.state(target_model)
-
-    def map_to_pspec(data):
-      if hasattr(data, "sharding"):
-        return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
-      return None
-
-    restore_args = jax.tree.map(map_to_pspec, params)
-
-    cp_restore_args = {
-        "model_params": checkpoint.args.PyTreeRestore(
-            item=params,
-            restore_args=restore_args,
-        )
-    }
-
-    if optimizer is not None:
-      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
-      opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
-      cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
-          item=optimizer_state,
-          restore_args=opt_restore_args,
-      )
-
-    restored = self._checkpoint_manager.restore(
-        step,
-        args=checkpoint.args.Composite(**cp_restore_args),
+    step, _ = super().maybe_restore(
+        model=target_model,
+        optimizer=optimizer,
+        restore_only_lora_params=restore_only_lora_params,
     )
+    if step == 0:
+      return 0, {}
 
-    nnx.update(target_model, restored.model_params)
-    if optimizer is not None:
-      nnx.update(optimizer, restored.optimizer_state)
+    max_logging.log(f"Restored from checkpoint step {step}.")
 
     metadata = self._checkpoint_manager.metadata(step)
     if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None: