Commit 25f7ba0

Merge pull request #3701 from AI-Hypercomputer:agagik-checkpoint
PiperOrigin-RevId: 903534582
2 parents bdae55d + 190a12d commit 25f7ba0

6 files changed

Lines changed: 1002 additions & 180 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 139 additions & 119 deletions
@@ -18,25 +18,25 @@
 model structures with Tunix's training interfaces.
 """

-import pickle
-import tensorflow as tf
-from array_record.python import array_record_module
-
 import abc
-from typing import Any, Iterator, Optional, List, Callable, Literal
+import pickle
+from typing import Any, Callable, Iterator, List, Literal, Optional, Sequence

 import flax
 from flax import nnx
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+import tensorflow as tf
+from array_record.python import array_record_module
 from orbax import checkpoint

 from maxtext.utils import max_logging
 # Reuse MaxText's native checkpointing logic
 from maxtext.common.checkpointing import GrainCheckpointHandler, GrainCheckpointSave, GrainCheckpointRestore
-from tunix.sft import peft_trainer
 from tunix.sft import checkpoint_manager as tunix_checkpoint_manager
+from tunix.sft import peft_trainer


 # -----------------------------------------------------------------------------

@@ -184,6 +184,11 @@ def __next__(self) -> MaxTextTrainingInput:
 # Distillation Strategy
 # -----------------------------------------------------------------------------

+# Clamp CE before exp() so a divergence spike doesn't poison PPL averages
+# with inf. 20 nats is well above plausible CE (Llama random-init ~11.76)
+# and far below fp32 exp overflow (~88).
+_PPL_CE_CAP = 20.0
+

 def compute_schedule(
     step: jax.Array,
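The cap value is picked against a hard fp32 ceiling: exp() overflows to inf just above 88 nats, and a single inf poisons every later average. A quick check outside the diff (the values here are illustrative):

import numpy as np

spike = np.float32(89.0)  # one divergent step's CE
print(np.exp(spike))  # inf: overflows float32, and inf + anything stays inf
print(np.exp(np.minimum(spike, np.float32(20.0))))  # ~4.85e8: large but finite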
@@ -217,6 +222,33 @@ def compute_schedule(
   raise ValueError(f"Unsupported schedule_type: {schedule_type!r}. Must be 'constant', 'linear', or 'cosine'.")


+def weighted_mean(sum_count_pairs: Sequence[tuple[Any, Any]] | np.ndarray) -> float:
+  """Aggregates `(sum, count)` pairs into a single token-weighted mean.
+
+  Used as the aggregation function for metrics emitted by `compute_loss` and
+  `compute_eval_loss`. Robust to per-host imbalance and to varying valid-token
+  counts across logging steps:
+    final_value = sum(sums) / sum(counts)
+
+  Accepts either a list of (sum, count) tuples or an ndarray of shape (N, 2).
+  Tunix's metrics writer can pass either form, so we normalize here.
+
+  Returns 0.0 for an empty input or when total count is non-positive.
+  """
+  arr = np.asarray(sum_count_pairs, dtype=np.float32)
+  if arr.size == 0:
+    return 0.0
+  # Normalize shape. Single pair -> (1, 2); list of pairs -> (N, 2).
+  if arr.ndim == 1:
+    arr = arr.reshape(1, -1)
+  if arr.ndim != 2 or arr.shape[1] != 2:
+    return 0.0
+  total = float(arr[:, 1].sum())
+  if total <= 0.0:
+    return 0.0
+  return float(arr[:, 0].sum() / total)
+
+
 class DistillationStrategy(abc.ABC):
   """Abstract base class for MaxText Distillation Strategies."""

@@ -312,15 +344,14 @@ def __init__(
     Args:
       student_forward_fn: Function to compute student model outputs.
       teacher_forward_fn: Function to compute teacher model outputs.
-      labels_fn: Function to compute labels from model inputs.
       temperature: Temperature for softening probabilities (> 0).
       alpha: Weight to balance distillation loss and task loss (0.0 to 1.0).
       beta_feature: Weight to balance feature loss (0.0 to 1.0). 0.0 disables feature loss.
       layer_indices: Layer indices to apply feature loss.
       feature_loss_type: The type of feature loss to use if `feature_loss_fn` is None.
        Can be "cosine" (default) or "l2".
-      feature_loss_fn: A function that takes two jax. Arrays (student_map,
-        teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
+      feature_loss_fn: A function that takes two jax.Arrays (student_map,
+        teacher_map) and returns a scalar loss. Defaults to cosine distance.
       cosine_distance_axis: The axis to use for cosine distance computation if
        feature_loss_fn is not provided. Defaults to -1.
       alpha_end: Target alpha value at end of training. None keeps alpha fixed.
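For reference, a caller-supplied feature_loss_fn only has to map two feature maps to a scalar. A hedged sketch of an L2 variant matching that contract (illustrative, not code from this change):

import jax
import jax.numpy as jnp

def l2_feature_loss(student_map: jax.Array, teacher_map: jax.Array) -> jax.Array:
  # Mean squared error over all layer/batch/seq/hidden positions, reduced to a scalar.
  return jnp.mean(jnp.square(student_map - teacher_map))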
@@ -418,18 +449,19 @@ def compute_loss(
       teacher_output: DistillationForwardOutput,
       labels: jax.Array,
       step: jax.Array | None = None,
-  ) -> tuple[jax.Array, dict[str, jax.Array]]:
-    """Computes Loss and Auxiliary Metrics."""
+  ) -> tuple[jax.Array, dict[str, tuple[jax.Array, jax.Array]]]:
+    """Computes Loss and Auxiliary Metrics.
+
+    Metrics are emitted as (sum, count) pairs so that they can be aggregated
+    across hosts and across logging windows in a token-weighted (unbiased) way:
+      final_value = sum(sums) / sum(counts).
+    """
     # Resolve scheduled weights for this step
     alpha, temperature, beta_feature = self._get_scheduled_weights(step)

-    # Calculate Distillation Loss (KL Divergence)
-    # Scale logits by temperature T for soft targets
-    # We use explicit float32 casting for stability in loss calculation
     s_logits = student_output.logits.astype(jnp.float32)
     t_logits = teacher_output.logits.astype(jnp.float32)

-    # Shape: [num_layers, batch, seq, hidden_dim]
     s_features = student_output.out_projection_activations
     t_features = teacher_output.out_projection_activations

@@ -439,34 +471,42 @@ def compute_loss(
         "Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
       )

-    log_student_probs_temp = jax.nn.log_softmax(s_logits / temperature, axis=-1)
-    teacher_probs_temp = jax.nn.softmax(t_logits / temperature, axis=-1)
-    # labels are supposed to have all sft masks applied by this moment
-    labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
-    mean_mask = jnp.squeeze(labels_mask, axis=-1)
-
-    # KL(Teacher || Student)
-    kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
-
-    # Scale gradients by T^2 (Hinton et al.)
-    soft_loss = jnp.mean(kl_div, where=mean_mask) * (temperature**2)
-
-    # 1. Student Hard Loss (Existing)
-    ce_loss_student = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
-    hard_loss = jnp.mean(ce_loss_student, where=mean_mask)
-
-    # 2. Teacher Hard Loss (For Verification)
-    ce_loss_teacher = optax.softmax_cross_entropy(logits=t_logits, labels=labels, where=labels_mask)
-    teacher_hard_loss = jnp.mean(ce_loss_teacher, where=mean_mask)
-
-    # 3. Combine losses
-    base_logit_loss = (alpha * soft_loss) + ((1.0 - alpha) * hard_loss)
-
-    feature_loss = jnp.array(0.0)
+    # Per-token validity mask, derived from the one-hot labels so we don't need
+    # a separate mask input. A padded (fully-zero) row yields `any != 0 == False`.
+    mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32)  # [B, T]
+    valid_count = jnp.sum(mask)
+    safe_count = jnp.maximum(valid_count, 1.0)
+
+    # --- Soft loss: KL on temperature-softened distributions ---
+    log_s_T = jax.nn.log_softmax(s_logits / temperature, axis=-1)
+    t_p_T = jax.nn.softmax(t_logits / temperature, axis=-1)
+    # KL(teacher || student) per position. optax.kl_divergence(log_pred, target) = KL(target || pred).
+    kl_softened_per_pos = optax.kl_divergence(log_s_T, t_p_T)  # [B, T]
+    kl_softened_sum = jnp.sum(kl_softened_per_pos * mask)
+    # Scale by T^2 (Hinton). Apply once at the loss; logged metric is the scaled sum too.
+    soft_loss_sum_scaled = kl_softened_sum * (temperature**2)
+    soft_loss_mean = soft_loss_sum_scaled / safe_count
+
+    # --- Hard loss: student CE against ground-truth ---
+    ce_student_per_pos = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
+    ce_student_sum = jnp.sum(ce_student_per_pos * mask)
+    hard_loss_mean = ce_student_sum / safe_count
+
+    # --- Teacher CE (verification metric) ---
+    ce_teacher_per_pos = optax.softmax_cross_entropy(logits=t_logits, labels=labels)
+    ce_teacher_sum = jnp.sum(ce_teacher_per_pos * mask)
+
+    # --- Always-T=1 KL for cross-run / cross-anneal comparability ---
+    log_s_1 = jax.nn.log_softmax(s_logits, axis=-1)
+    t_p_1 = jax.nn.softmax(t_logits, axis=-1)
+    kl_t1_per_pos = optax.kl_divergence(log_s_1, t_p_1)
+    kl_t1_sum = jnp.sum(kl_t1_per_pos * mask)
+
+    base_logit_loss = (alpha * soft_loss_mean) + ((1.0 - alpha) * hard_loss_mean)
+
+    feature_loss = jnp.array(0.0, dtype=jnp.float32)
     if self.beta_feature > 0.0:
-
       if self.layer_indices is not None:
-        # jnp.take slices along axis=0 (the layer dimension)
         s_features_sliced = jnp.take(s_features, self.layer_indices, axis=0)
         t_features_sliced = jnp.take(t_features, self.layer_indices, axis=0)
       else:
480520

481521
total_loss = base_logit_loss + feature_loss
482522

483-
# 4. Return Loss AND Metrics (log dynamic values for TensorBoard verification)
484-
metrics = {
485-
"distill/soft_loss": soft_loss,
486-
"distill/hard_loss": hard_loss,
487-
"distill/kl_div": jnp.mean(kl_div, where=mean_mask),
488-
"distill/teacher_loss": teacher_hard_loss,
489-
"distill/out_proj_feature_loss": feature_loss,
490-
"distill/total_loss": total_loss,
491-
"distill/temperature": temperature,
492-
"distill/alpha": alpha,
493-
"distill/beta_feature": beta_feature,
523+
# Per-step next-token perplexity. Note: this is mean(exp(per-step CE)), not
524+
# exp(window-CE-mean) — close to true perplexity in steady state. For the exact
525+
# perplexity over a logging window compute exp(distill/hard_loss) on the TB side.
526+
teacher_loss_mean = ce_teacher_sum / safe_count
527+
student_perplexity_step = jnp.exp(jnp.minimum(hard_loss_mean, _PPL_CE_CAP))
528+
teacher_perplexity_step = jnp.exp(jnp.minimum(teacher_loss_mean, _PPL_CE_CAP))
529+
530+
one = jnp.array(1.0, dtype=jnp.float32)
531+
metrics: dict[str, tuple[jax.Array, jax.Array]] = {
532+
# Token-weighted: emit (sum, valid_count) so multi-host averaging is unbiased.
533+
"distill/soft_loss": (soft_loss_sum_scaled, valid_count),
534+
"distill/hard_loss": (ce_student_sum, valid_count),
535+
"distill/teacher_loss": (ce_teacher_sum, valid_count),
536+
# Next-token prediction perplexity (per-step approximation of exp(hard_loss)).
537+
# The headline `_train_perplexity` Tunix prints is exp(total_loss) which for
538+
# distillation is exp(α·soft + (1-α)·hard + β·feature) and NOT next-token PPL.
539+
"distill/student_perplexity": (student_perplexity_step, one),
540+
"distill/teacher_perplexity": (teacher_perplexity_step, one),
541+
# KL at the current (scheduled) temperature T, without the T^2 scaling
542+
# that soft_loss applies. Pair with kl_div_T1 to compare T vs T=1.
543+
"distill/kl_div_at_T": (kl_softened_sum, valid_count),
544+
# KL at T=1: comparable across runs / annealing schedules.
545+
"distill/kl_div_T1": (kl_t1_sum, valid_count),
546+
# Per-step quantities: (value, 1.0) so the aggregator yields a simple mean over steps.
547+
"distill/out_proj_feature_loss": (feature_loss, one),
548+
"distill/total_loss": (total_loss, one),
549+
"distill/temperature": (temperature, one),
550+
"distill/alpha": (alpha, one),
551+
"distill/beta_feature": (beta_feature, one),
494552
}
495553
return total_loss, metrics
496554

497555
def compute_eval_loss(
498556
self,
499557
student_output: DistillationForwardOutput,
500558
labels: jax.Array,
501-
) -> tuple[jax.Array, dict[str, jax.Array]]:
502-
"""Computes Eval Loss and returns empty aux dict (required for consistency)."""
503-
# Parent logic for task loss
504-
# We re-implement simple CE here to ensure float32 casting
559+
) -> tuple[jax.Array, dict[str, tuple[jax.Array, jax.Array]]]:
560+
"""Computes Eval Loss. Returns (loss, metrics) with (sum, count) metric pairs."""
505561
s_logits = student_output.logits.astype(jnp.float32)
506562

507-
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
508-
mean_mask = jnp.squeeze(labels_mask, axis=-1)
509-
ce_loss = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
510-
task_loss = jnp.mean(ce_loss, where=mean_mask)
563+
mask = jnp.any(labels != 0, axis=-1).astype(jnp.float32)
564+
valid_count = jnp.sum(mask)
565+
safe_count = jnp.maximum(valid_count, 1.0)
566+
567+
ce_per_pos = optax.softmax_cross_entropy(logits=s_logits, labels=labels)
568+
ce_sum = jnp.sum(ce_per_pos * mask)
569+
task_loss = ce_sum / safe_count
511570

512-
# Must return a tuple because _has_aux=True expects it
513-
return task_loss, {}
571+
metrics = {
572+
"eval/hard_loss": (ce_sum, valid_count),
573+
"eval/student_perplexity": (jnp.exp(jnp.minimum(task_loss, _PPL_CE_CAP)), jnp.array(1.0, dtype=jnp.float32)),
574+
}
575+
return task_loss, metrics
514576

515577
def create_labels(self, targets, targets_segmentation=None, **kwargs):
516578
"""Converts integer targets to masked one-hot vectors for hard label loss."""
@@ -574,11 +636,13 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
     if not force and not self._checkpoint_manager.should_save(step):
       return False

-    # Standard Tunix Logic for Model/Optimizer
+    # Standard Tunix Logic for Model/Optimizer.
+    # Accept either a ModelBundle (common path) or a plain nnx module.
+    target_model = getattr(model, "student_model", model)
     if save_only_lora_params:
-      params = nnx.state(model.student_model, nnx.LoRAParam)
+      params = nnx.state(target_model, nnx.LoRAParam)
     else:
-      params = nnx.state(model.student_model)
+      params = nnx.state(target_model)

     # Define standard SaveArgs once to reuse
     default_save_args = checkpoint.SaveArgs()
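The getattr fallback is what lets save() take either input shape. In isolation, with a stand-in wrapper class (the real ModelBundle is defined elsewhere in MaxText):

class FakeBundle:  # stand-in for a ModelBundle-like wrapper, for illustration only
  def __init__(self, student_model):
    self.student_model = student_model

plain = object()
bundle = FakeBundle(student_model=plain)

assert getattr(bundle, "student_model", bundle) is plain  # bundle: unwrapped
assert getattr(plain, "student_model", plain) is plain  # plain module: passthrough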
@@ -625,71 +689,27 @@ def maybe_restore(
       optimizer: Any = None,
       restore_only_lora_params: bool = False,
   ) -> tuple[int, dict[str, Any]]:
-    """Restores model and optimizer state if a checkpoint exists, using correct sharding specs.
-
-    This method checks for the latest available checkpoint. If found, it restores the
-    model parameters and optionally the optimizer state in-place. It automatically
-    maps the parameter's `sharding` attributes to Orbax restore arguments to ensure
-    the tensors are placed on the correct device meshes.
+    """Restores model + optimizer by delegating to upstream Tunix.

-    Args:
-      model: The model to restore. If a `ModelBundle` is provided, it automatically
-        extracts and restores only the `student_model`.
-      optimizer: The optimizer state to restore. If None, optimizer restoration is skipped.
-      restore_only_lora_params: If True, restricts restoration to parameters marked
-        as `nnx.LoRAParam`.
+    Unwraps `ModelBundle` if present (we only restore `student_model`).

     Returns:
-      A tuple containing the restored step number (0 if no checkpoint was found)
-      and a dictionary of custom metadata.
+      (restored step, custom_metadata dict). Step is 0 if no checkpoint exists.
     """
     if self._checkpoint_manager is None:
       return 0, {}

-    step = self._checkpoint_manager.latest_step()
-    if step is None:
-      return 0, {}
-
-    max_logging.log(f"Restoring from checkpoint step {step}...")
-
-    # Extract student model safely
     target_model = getattr(model, "student_model", model)

-    if restore_only_lora_params:
-      params = nnx.state(target_model, nnx.LoRAParam)
-    else:
-      params = nnx.state(target_model)
-
-    def map_to_pspec(data):
-      if hasattr(data, "sharding"):
-        return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
-      return None
-
-    restore_args = jax.tree.map(map_to_pspec, params)
-
-    cp_restore_args = {
-        "model_params": checkpoint.args.PyTreeRestore(
-            item=params,
-            restore_args=restore_args,
-        )
-    }
-
-    if optimizer is not None:
-      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
-      opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
-      cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
-          item=optimizer_state,
-          restore_args=opt_restore_args,
-      )
-
-    restored = self._checkpoint_manager.restore(
-        step,
-        args=checkpoint.args.Composite(**cp_restore_args),
+    step, _ = super().maybe_restore(
+        model=target_model,
+        optimizer=optimizer,
+        restore_only_lora_params=restore_only_lora_params,
     )
+    if step == 0:
+      return 0, {}

-    nnx.update(target_model, restored.model_params)
-    if optimizer is not None:
-      nnx.update(optimizer, restored.optimizer_state)
+    max_logging.log(f"Restored from checkpoint step {step}.")

     metadata = self._checkpoint_manager.metadata(step)
     if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
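The pattern this hunk lands on is: unwrap the bundle, defer restoration to the parent class, and keep only the subclass's metadata handling. A self-contained sketch of that shape (the class names here are illustrative, not the real Tunix classes):

class Base:
  def maybe_restore(self, model, optimizer=None, restore_only_lora_params=False):
    return 0, {}  # pretend no checkpoint exists

class Manager(Base):
  def maybe_restore(self, model, optimizer=None, restore_only_lora_params=False):
    target_model = getattr(model, "student_model", model)  # unwrap ModelBundle
    step, _ = super().maybe_restore(
        model=target_model,
        optimizer=optimizer,
        restore_only_lora_params=restore_only_lora_params,
    )
    if step == 0:
      return 0, {}
    return step, {"restored": True}  # subclass-specific metadata handling

print(Manager().maybe_restore(model=object()))  # (0, {})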
