Commit 7638c8f

Refactored the distillation input pipeline and checkpoint manager to support state restoration via maybe_restore.
1 parent 161f69a commit 7638c8f

3 files changed: 545 additions & 159 deletions

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 88 additions & 8 deletions
@@ -44,25 +44,25 @@ class DistillationForwardOutput:
   """Dataclass to carry MaxText-specific output fields."""
 
   #: logits
-  logits: jax.Array = None
+  logits: jax.Array
   #: out_projection_activations
-  out_projection_activations: jax.Array = None
+  out_projection_activations: jax.Array | None = None
 
 
 @flax.struct.dataclass(frozen=True)
 class MaxTextTrainingInput(peft_trainer.TrainingInput):
   """Extended TrainingInput dataclass to carry MaxText-specific fields."""
 
   #: Position indices for the tokens (for RoPE).
-  positions: jax.Array = None
+  positions: jax.Array | None = None
   #: Segment IDs for packed sequences (0=padding, 1+=examples).
-  decoder_segment_ids: jax.Array = None
+  decoder_segment_ids: jax.Array | None = None
   #: Ground truth target tokens (used for loss calculation and logging).
-  targets: jax.Array = None
+  targets: jax.Array | None = None
   #: Position indices for the target tokens.
-  targets_position: jax.Array = None
+  targets_position: jax.Array | None = None
   #: Segment IDs for packed target tokens.
-  targets_segmentation: jax.Array = None
+  targets_segmentation: jax.Array | None = None
 
 
 # -----------------------------------------------------------------------------
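
The hunk above tightens the dataclass contracts: `logits` loses its invalid `None` default (its annotation is plain `jax.Array`, so it must always be supplied), while the genuinely optional fields are retyped `jax.Array | None`. A minimal sketch of the effect at a call site, assuming the class behaves like a standard frozen dataclass; the import path follows the changed file's location and the shapes are made up:

import jax.numpy as jnp

from maxtext.trainers.post_train.distillation.distillation_utils import (
    DistillationForwardOutput,
)

# `logits` must now be passed explicitly; omitting it raises a TypeError at
# construction time instead of silently propagating logits=None downstream.
out = DistillationForwardOutput(logits=jnp.zeros((2, 128, 32000)))

# The optional activations field still defaults to None, now typed to say so.
assert out.out_projection_activations is None
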
@@ -222,7 +222,7 @@ def compute_loss(
     # 3. Combine losses
     base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)
 
-    feature_loss = 0.0
+    feature_loss = jnp.array(0.0)
     if self.beta_feature > 0.0:
 
       if self.layer_indices is not None:
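
The one-line change above, `feature_loss = jnp.array(0.0)`, presumably keeps `feature_loss` a JAX array whether or not the `beta_feature` branch runs, so the combined loss has a consistent array type and dtype under `jax.jit` tracing. A standalone sketch of the pattern; the function and the placeholder feature term are illustrative, not from the commit:

import jax
import jax.numpy as jnp

def combine_losses(soft_loss, hard_loss, alpha=0.5, beta_feature=0.0):
  base_logit_loss = (alpha * soft_loss) + ((1.0 - alpha) * hard_loss)
  feature_loss = jnp.array(0.0)  # an array, not the Python float 0.0
  if beta_feature > 0.0:  # Python-level branch, resolved at trace time
    feature_loss = beta_feature * jnp.mean(soft_loss)  # placeholder term
  return base_logit_loss + feature_loss

print(jax.jit(combine_losses)(jnp.array(1.0), jnp.array(2.0)))  # prints 1.5
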
@@ -364,6 +364,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
         force=force,
     )
 
+  def maybe_restore(
+      self,
+      model: Any,
+      optimizer: Any = None,
+      restore_only_lora_params: bool = False,
+  ) -> tuple[int, dict[str, Any]]:
+    """Restores model and optimizer state if a checkpoint exists, using correct sharding specs.
+
+    This method checks for the latest available checkpoint. If one is found, it restores
+    the model parameters and optionally the optimizer state in-place. It automatically
+    maps each parameter's `sharding` attribute to Orbax restore arguments to ensure
+    the tensors are placed on the correct device meshes.
+
+    Args:
+      model: The model to restore. If a `ModelBundle` is provided, it automatically
+        extracts and restores only the `student_model`.
+      optimizer: The optimizer state to restore. If None, optimizer restoration is skipped.
+      restore_only_lora_params: If True, restricts restoration to parameters marked
+        as `nnx.LoRAParam`.
+
+    Returns:
+      A tuple containing the restored step number (0 if no checkpoint was found)
+      and a dictionary of custom metadata.
+    """
+    if self._checkpoint_manager is None:
+      return 0, {}
+
+    step = self._checkpoint_manager.latest_step()
+    if step is None:
+      return 0, {}
+
+    max_logging.log(f"Restoring from checkpoint step {step}...")
+
+    # Extract the student model if a bundle was passed; plain models pass through.
+    target_model = getattr(model, "student_model", model)
+
+    if restore_only_lora_params:
+      params = nnx.state(target_model, nnx.LoRAParam)
+    else:
+      params = nnx.state(target_model)
+
+    def map_to_pspec(data):
+      if hasattr(data, "sharding"):
+        return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
+      return None
+
+    restore_args = jax.tree.map(map_to_pspec, params)
+
+    cp_restore_args = {
+        "model_params": checkpoint.args.PyTreeRestore(
+            item=params,
+            restore_args=restore_args,
+        )
+    }
+
+    if optimizer is not None:
+      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
+      opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
+      cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
+          item=optimizer_state,
+          restore_args=opt_restore_args,
+      )
+
+    restored = self._checkpoint_manager.restore(
+        step,
+        args=checkpoint.args.Composite(**cp_restore_args),
+    )
+
+    nnx.update(target_model, restored.model_params)
+    if optimizer is not None:
+      nnx.update(optimizer, restored.optimizer_state)
+
+    metadata = self._checkpoint_manager.metadata(step)
+    if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
+      custom_metadata = metadata.custom_metadata
+    else:
+      custom_metadata = {}
+
+    return step, dict(custom_metadata)
+
   def restore_iterator(self):
     """Restores the iterator using MaxText's logic."""
     if self._checkpoint_manager is None or self._iterator is None:
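
The core trick in `maybe_restore` is deriving Orbax restore arguments from the live parameters' `sharding` attributes, so restored tensors land on the same device mesh as the in-memory templates. A standalone sketch of that pattern, assuming the module's `checkpoint` alias refers to `orbax.checkpoint`:

import jax
import orbax.checkpoint as ocp

def build_restore_args(params):
  # For every leaf of the live parameter pytree that carries a `sharding`
  # attribute, emit an ArrayRestoreArgs so Orbax restores the tensor onto
  # the same mesh the template array lives on.
  def map_to_pspec(leaf):
    if hasattr(leaf, "sharding"):
      return ocp.type_handlers.ArrayRestoreArgs(sharding=leaf.sharding)
    return None  # non-array leaves need no special restore handling
  return jax.tree.map(map_to_pspec, params)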

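To show where the new method slots in, here is a hypothetical resume flow; `manager`, `model`, `optimizer`, and `total_steps` stand in for objects the trainer builds elsewhere and are not part of this commit:

# maybe_restore returns (0, {}) when there is no manager or no checkpoint,
# so one call covers both cold starts and resumes.
start_step, custom_metadata = manager.maybe_restore(model, optimizer=optimizer)

if start_step == 0:
  print("No checkpoint found; training from scratch.")
else:
  print(f"Resumed from step {start_step}; metadata: {custom_metadata}")

for step in range(start_step, total_steps):
  ...  # usual training step; subsequent saves go back through `manager`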