Skip to content

Commit dbc1584

Browse files
Merge pull request #3683 from AI-Hypercomputer:agagik-dynamic-distill
PiperOrigin-RevId: 901504218
2 parents 147168c + 9736256 commit dbc1584

6 files changed

Lines changed: 823 additions & 14 deletions

File tree

src/maxtext/configs/base.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,14 @@ distill_beta: 0.0
12231223
# distill_feature_loss_type is the type of loss to use for feature distillation ("cosine" or "l2").
12241224
distill_feature_loss_type: "cosine"
12251225
distill_layer_indices: None
1226+
# Dynamic loss weight scheduling: set *_end to a target value and *_schedule to "linear" or "cosine".
1227+
# When *_end is None (default), the corresponding weight stays fixed throughout training.
1228+
distill_alpha_end: None
1229+
distill_alpha_schedule: "constant"
1230+
distill_temperature_end: None
1231+
distill_temperature_schedule: "constant"
1232+
distill_beta_end: None
1233+
distill_beta_schedule: "constant"
12261234

12271235
##### Elastic training parameters
12281236
# Elastic training is Pathways-specific and does not work on McJAX.

src/maxtext/configs/types.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,22 @@ class Distillation(BaseModel):
11601160
"cosine", description="The type of loss to use for feature distillation ('cosine' or 'l2')."
11611161
)
11621162
distill_layer_indices: None | list = Field(None, description="Feature indices for feature loss.")
1163+
distill_alpha_end: Optional[float] = Field(None, description="Target alpha at end of training. None keeps alpha fixed.")
1164+
distill_alpha_schedule: Literal["constant", "linear", "cosine"] = Field(
1165+
"constant", description="Schedule type for alpha annealing ('constant', 'linear', or 'cosine')."
1166+
)
1167+
distill_temperature_end: Optional[float] = Field(
1168+
None, description="Target temperature at end of training. None keeps temperature fixed."
1169+
)
1170+
distill_temperature_schedule: Literal["constant", "linear", "cosine"] = Field(
1171+
"constant", description="Schedule type for temperature annealing ('constant', 'linear', or 'cosine')."
1172+
)
1173+
distill_beta_end: Optional[float] = Field(
1174+
None, description="Target beta_feature at end of training. None keeps beta fixed."
1175+
)
1176+
distill_beta_schedule: Literal["constant", "linear", "cosine"] = Field(
1177+
"constant", description="Schedule type for beta annealing ('constant', 'linear', or 'cosine')."
1178+
)
11631179

11641180
# --- Distillation freezing filter --
11651181
student_params_to_update: None | list = Field(
@@ -2251,6 +2267,30 @@ def validate_and_set_hlo_dump_defaults():
22512267
if not self.enable_nnx:
22522268
raise ValueError("a value of self.distill_beta > 0.0 requires self.enable_nnx = True")
22532269

2270+
# Validate distillation schedule parameters
2271+
if self.distill_alpha_end is not None and not 0.0 <= self.distill_alpha_end <= 1.0:
2272+
raise ValueError(f"distill_alpha_end must be in [0, 1], got {self.distill_alpha_end}")
2273+
if self.distill_temperature_end is not None and self.distill_temperature_end <= 0.0:
2274+
raise ValueError(f"distill_temperature_end must be > 0, got {self.distill_temperature_end}")
2275+
if self.distill_beta_end is not None and self.distill_beta_end < 0.0:
2276+
raise ValueError(f"distill_beta_end must be >= 0, got {self.distill_beta_end}")
2277+
if self.distill_beta == 0.0 and self.distill_beta_end is not None and self.distill_beta_end > 0.0:
2278+
raise ValueError(
2279+
f"distill_beta=0.0 but distill_beta_end={self.distill_beta_end}. "
2280+
"Feature extraction is disabled when distill_beta starts at 0.0. "
2281+
"Set distill_beta to a small positive value (e.g., 1e-6) to enable feature extraction."
2282+
)
2283+
for param_name, schedule, end_value in [
2284+
("distill_alpha", self.distill_alpha_schedule, self.distill_alpha_end),
2285+
("distill_temperature", self.distill_temperature_schedule, self.distill_temperature_end),
2286+
("distill_beta", self.distill_beta_schedule, self.distill_beta_end),
2287+
]:
2288+
if schedule != "constant" and end_value is None:
2289+
raise ValueError(
2290+
f"{param_name}_schedule is '{schedule}' but {param_name}_end is None. "
2291+
f"Set {param_name}_end to a target value or use schedule='constant'."
2292+
)
2293+
22542294
# D. CALCULATE MODEL DIMENSIONS from global_parameter_scale
22552295
# This allows scaling the model size up or down easily with a single power-of-two factor.
22562296
emb_scale, num_head_scale, mlp_dim_scale, layer_scale = get_individual_scales(self.global_parameter_scale)

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 117 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,38 @@ def __next__(self) -> MaxTextTrainingInput:
185185
# -----------------------------------------------------------------------------
186186

187187

188+
def compute_schedule(
189+
step: jax.Array,
190+
max_steps: int,
191+
start_value: float,
192+
end_value: float | None,
193+
schedule_type: str,
194+
) -> jax.Array:
195+
"""Computes a scheduled value based on training progress.
196+
197+
Args:
198+
step: Current training step as a JAX array.
199+
max_steps: Total number of training steps.
200+
start_value: Value at the beginning of training.
201+
end_value: Value at the end of training. If None, returns start_value.
202+
schedule_type: One of "constant", "linear", or "cosine".
203+
204+
Returns:
205+
The scheduled value as a JAX scalar.
206+
"""
207+
if end_value is None or schedule_type == "constant":
208+
return jnp.array(start_value, dtype=jnp.float32)
209+
210+
progress = jnp.clip(step.astype(jnp.float32) / max_steps, 0.0, 1.0)
211+
212+
if schedule_type == "linear":
213+
return start_value + (end_value - start_value) * progress
214+
elif schedule_type == "cosine":
215+
return end_value + (start_value - end_value) * 0.5 * (1.0 + jnp.cos(jnp.pi * progress))
216+
else:
217+
raise ValueError(f"Unsupported schedule_type: {schedule_type!r}. Must be 'constant', 'linear', or 'cosine'.")
218+
219+
188220
class DistillationStrategy(abc.ABC):
189221
"""Abstract base class for MaxText Distillation Strategies."""
190222

@@ -210,13 +242,15 @@ def compute_loss(
210242
student_output: "DistillationForwardOutput",
211243
teacher_output: "DistillationForwardOutput",
212244
labels: jax.Array,
245+
step: jax.Array | None = None,
213246
) -> tuple[jax.Array, dict[str, jax.Array]]:
214247
"""Computes the distillation loss.
215248
216249
Args:
217250
student_output: The forward pass output of the student model.
218251
teacher_output: The forward pass output of the frozen teacher model.
219252
labels: The masked one-hot encoded ground truth labels.
253+
step: Current training step for dynamic scheduling. If None, uses fixed values.
220254
221255
Returns:
222256
A tuple containing the scalar loss and a dictionary of auxiliary metrics
@@ -265,8 +299,15 @@ def __init__(
265299
feature_loss_type: Literal["cosine", "l2"] = "cosine",
266300
cosine_distance_axis: int | tuple[int, ...] = -1,
267301
vocab_size: int = 0,
302+
alpha_end: float | None = None,
303+
alpha_schedule: Literal["constant", "linear", "cosine"] = "constant",
304+
temperature_end: float | None = None,
305+
temperature_schedule: Literal["constant", "linear", "cosine"] = "constant",
306+
beta_end: float | None = None,
307+
beta_schedule: Literal["constant", "linear", "cosine"] = "constant",
308+
max_steps: int = 1,
268309
):
269-
"""Initializes the Combined strategy using tunix logit.LogitStrategy.
310+
"""Initializes the Combined distillation strategy.
270311
271312
Args:
272313
student_forward_fn: Function to compute student model outputs.
@@ -282,6 +323,13 @@ def __init__(
282323
teacher_map) and returns a scalar loss. Defaults to Cosine Distance.
283324
cosine_distance_axis: The axis to use for cosine distance computation if
284325
feature_loss_fn is not provided. Defaults to -1.
326+
alpha_end: Target alpha value at end of training. None keeps alpha fixed.
327+
alpha_schedule: Schedule type for alpha annealing.
328+
temperature_end: Target temperature at end of training. None keeps temperature fixed.
329+
temperature_schedule: Schedule type for temperature annealing.
330+
beta_end: Target beta_feature value at end of training. None keeps beta fixed.
331+
beta_schedule: Schedule type for beta annealing.
332+
max_steps: Total training steps, used for schedule computation.
285333
"""
286334

287335
super().__init__(
@@ -296,6 +344,39 @@ def __init__(
296344
self.beta_feature = beta_feature
297345
self.layer_indices = jnp.array(layer_indices) if layer_indices is not None else None
298346

347+
# Schedule parameters
348+
self.alpha_end = alpha_end
349+
self.alpha_schedule = alpha_schedule
350+
self.temperature_end = temperature_end
351+
self.temperature_schedule = temperature_schedule
352+
self.beta_end = beta_end
353+
self.beta_schedule = beta_schedule
354+
self.max_steps = max_steps
355+
356+
# Validate schedule parameter ranges
357+
if alpha_end is not None and not 0.0 <= alpha_end <= 1.0:
358+
raise ValueError(f"alpha_end must be in [0, 1], got {alpha_end}")
359+
if temperature_end is not None and temperature_end <= 0.0:
360+
raise ValueError(f"temperature_end must be > 0, got {temperature_end}")
361+
if beta_end is not None and beta_end < 0.0:
362+
raise ValueError(f"beta_end must be >= 0, got {beta_end}")
363+
if beta_feature == 0.0 and beta_end is not None and beta_end > 0.0:
364+
raise ValueError(
365+
f"distill_beta=0.0 but distill_beta_end={beta_end}. Feature extraction is disabled when "
366+
"distill_beta starts at 0.0 (the model does not sow intermediate activations). "
367+
"Set distill_beta to a small positive value (e.g., 1e-6) to enable feature extraction."
368+
)
369+
for param_name, schedule, end_value in [
370+
("alpha", alpha_schedule, alpha_end),
371+
("temperature", temperature_schedule, temperature_end),
372+
("beta", beta_schedule, beta_end),
373+
]:
374+
if schedule != "constant" and end_value is None:
375+
raise ValueError(
376+
f"{param_name}_schedule is '{schedule}' but {param_name}_end is None. "
377+
f"Set {param_name}_end to a target value or use schedule='constant'."
378+
)
379+
299380
self.feature_loss_fn = feature_loss_fn
300381
if feature_loss_fn is None:
301382
if feature_loss_type == "cosine":
@@ -309,13 +390,39 @@ def __init__(
309390
else:
310391
raise ValueError(f"Unsupported feature_loss_type: {feature_loss_type!r}")
311392

393+
def _get_scheduled_weights(self, step: jax.Array | None) -> tuple[jax.Array, jax.Array, jax.Array]:
  """Resolves the current alpha, temperature, and beta values from schedules.

  Args:
    step: Current training step. If None, returns the fixed initial values.

  Returns:
    A tuple of (alpha, temperature, beta_feature) as JAX scalars.
  """
  # Without a step there is nothing to schedule; hand back the static values.
  if step is None:
    return tuple(jnp.array(v, dtype=jnp.float32) for v in (self.alpha, self.temperature, self.beta_feature))

  # (start, end, schedule) triples, in (alpha, temperature, beta) order.
  specs = (
      (self.alpha, self.alpha_end, self.alpha_schedule),
      (self.temperature, self.temperature_end, self.temperature_schedule),
      (self.beta_feature, self.beta_end, self.beta_schedule),
  )
  alpha, temperature, beta_feature = (
      compute_schedule(step, self.max_steps, start, end, sched) for start, end, sched in specs
  )
  return alpha, temperature, beta_feature
312415
def compute_loss(
313416
self,
314417
student_output: DistillationForwardOutput,
315418
teacher_output: DistillationForwardOutput,
316419
labels: jax.Array,
420+
step: jax.Array | None = None,
317421
) -> tuple[jax.Array, dict[str, jax.Array]]:
318422
"""Computes Loss and Auxiliary Metrics."""
423+
# Resolve scheduled weights for this step
424+
alpha, temperature, beta_feature = self._get_scheduled_weights(step)
425+
319426
# Calculate Distillation Loss (KL Divergence)
320427
# Scale logits by temperature T for soft targets
321428
# We use explicit float32 casting for stability in loss calculation
@@ -332,8 +439,8 @@ def compute_loss(
332439
"Ensure the model architecture supports feature extraction (e.g., 'out_projection_activations' is sowed)."
333440
)
334441

335-
log_student_probs_temp = jax.nn.log_softmax(s_logits / self.temperature, axis=-1)
336-
teacher_probs_temp = jax.nn.softmax(t_logits / self.temperature, axis=-1)
442+
log_student_probs_temp = jax.nn.log_softmax(s_logits / temperature, axis=-1)
443+
teacher_probs_temp = jax.nn.softmax(t_logits / temperature, axis=-1)
337444
# labels are supposed to have all sft masks applied by this moment
338445
labels_mask = jnp.any(labels != 0, axis=-1, keepdims=True)
339446
mean_mask = jnp.squeeze(labels_mask, axis=-1)
@@ -342,7 +449,7 @@ def compute_loss(
342449
kl_div = optax.kl_divergence(log_student_probs_temp, teacher_probs_temp, where=labels_mask)
343450

344451
# Scale gradients by T^2 (Hinton et al.)
345-
soft_loss = jnp.mean(kl_div, where=mean_mask) * (self.temperature**2)
452+
soft_loss = jnp.mean(kl_div, where=mean_mask) * (temperature**2)
346453

347454
# 1. Student Hard Loss (Existing)
348455
ce_loss_student = optax.softmax_cross_entropy(logits=s_logits, labels=labels, where=labels_mask)
@@ -353,7 +460,7 @@ def compute_loss(
353460
teacher_hard_loss = jnp.mean(ce_loss_teacher, where=mean_mask)
354461

355462
# 3. Combine losses
356-
base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
463+
base_logit_loss = (alpha * soft_loss) + ((1.0 - alpha) * hard_loss)
357464

358465
feature_loss = jnp.array(0.0)
359466
if self.beta_feature > 0.0:
@@ -369,21 +476,21 @@ def compute_loss(
369476
s_features_sliced = s_features_sliced.astype(jnp.float32)
370477
t_features_sliced = t_features_sliced.astype(jnp.float32)
371478

372-
feature_loss = self.beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
479+
feature_loss = beta_feature * self.feature_loss_fn(s_features_sliced, t_features_sliced)
373480

374481
total_loss = base_logit_loss + feature_loss
375482

376-
# 4. Return Loss AND Metrics
483+
# 4. Return Loss AND Metrics (log dynamic values for TensorBoard verification)
377484
metrics = {
378485
"distill/soft_loss": soft_loss,
379486
"distill/hard_loss": hard_loss,
380487
"distill/kl_div": jnp.mean(kl_div, where=mean_mask),
381488
"distill/teacher_loss": teacher_hard_loss,
382489
"distill/out_proj_feature_loss": feature_loss,
383490
"distill/total_loss": total_loss,
384-
"distill/temperature": self.temperature,
385-
"distill/alpha": self.alpha,
386-
"distill/beta_feature": self.beta_feature,
491+
"distill/temperature": temperature,
492+
"distill/alpha": alpha,
493+
"distill/beta_feature": beta_feature,
387494
}
388495
return total_loss, metrics
389496

src/maxtext/trainers/post_train/distillation/train_distill.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ class ModelBundle(nnx.Module):
182182
def __init__(self, teacher_model: nnx.Module, student_model: nnx.Module):
  """Bundles the frozen teacher and trainable student into one nnx.Module."""
  self.teacher_model = teacher_model
  self.student_model = student_model
  # Scalar step counter carried as module state so loss-weight schedules can
  # read the current step inside the jitted train step. Starts at 0; the
  # trainer increments it each step and re-syncs it after checkpoint restore.
  self.training_step = nnx.Variable(jnp.zeros((), dtype=jnp.int32))
185186

186187
def __call__(self, *args, **kwargs):
187188
raise NotImplementedError("Use `call_student` or `call_teacher` explicitly.")
@@ -269,6 +270,7 @@ def _train_step(self, model, optimizer, inputs):
269270
"""Overrides the main JIT block to natively handle ModelBundle module."""
270271

271272
batch = self.gen_model_input_fn(inputs)
273+
current_step = model.training_step.value
272274

273275
def loss_wrapper(student, teacher, batch):
274276
if "teacher_output" in batch:
@@ -299,7 +301,7 @@ def loss_wrapper(student, teacher, batch):
299301
)
300302
# we should apply a mask for labels to disable segment-separator tokens
301303
labels = self.strategy.create_labels(batch["targets"], targets_segmentation=batch.get("targets_segmentation", None))
302-
return self.strategy.compute_loss(student_output, teacher_output, labels)
304+
return self.strategy.compute_loss(student_output, teacher_output, labels, step=current_step)
303305

304306
# Because student is the 0th argument, argnums=0 guarantees
305307
# we only compute gradients for the student.
@@ -311,6 +313,9 @@ def loss_wrapper(student, teacher, batch):
311313

312314
out, grads = grad_fn(model.student_model, model.teacher_model, batch)
313315

316+
# Increment step counter after loss computation
317+
model.training_step.value = current_step + 1
318+
314319
tunix_expects_grad_norm = getattr(self, "_tunix_expects_grad_norm", True)
315320

316321
optimizer.update(model.student_model, grads)
@@ -509,6 +514,13 @@ def build_training_components(
509514
layer_indices=student_config.distill_layer_indices,
510515
feature_loss_type=student_config.distill_feature_loss_type,
511516
vocab_size=student_config.vocab_size,
517+
alpha_end=student_config.distill_alpha_end,
518+
alpha_schedule=student_config.distill_alpha_schedule,
519+
temperature_end=student_config.distill_temperature_end,
520+
temperature_schedule=student_config.distill_temperature_schedule,
521+
beta_end=student_config.distill_beta_end,
522+
beta_schedule=student_config.distill_beta_schedule,
523+
max_steps=student_config.steps,
512524
)
513525

514526
# 4. Optimizer & Config
@@ -632,6 +644,10 @@ def student_freeze_param_fn(path) -> bool:
632644
# Replace the default CheckpointManager with a Grain-aware one, which enables iterator checkpointing for grain datasets.
633645
raw_train_iter = trainer.setup_checkpoint_manager_and_restore(raw_train_iter, student_config)
634646

647+
# Sync the ModelBundle step counter with the restored training step so that
648+
# loss weight schedules resume from the correct position after checkpoint restore.
649+
model_bundle.training_step.set_value(jnp.array(trainer._train_steps, dtype=jnp.int32)) # pylint: disable=protected-access
650+
635651
# 6. Configure Input Mapping
636652
def custom_gen_model_input_fn(batch):
637653
inputs_dict = {

0 commit comments

Comments
 (0)