Commit ed51fc7

Refactored the distillation input pipeline and checkpoint manager to support state restoration via maybe_restore.

1 parent: 161f69a

3 files changed: 504 additions & 154 deletions

File tree

src/maxtext/trainers/post_train/distillation/distillation_utils.py

Lines changed: 71 additions & 8 deletions
@@ -44,25 +44,25 @@ class DistillationForwardOutput:
   """Dataclass to carry MaxText-specific output fields."""

   #: logits
-  logits: jax.Array = None
+  logits: jax.Array
   #: out_projection_activations
-  out_projection_activations: jax.Array = None
+  out_projection_activations: jax.Array | None = None


 @flax.struct.dataclass(frozen=True)
 class MaxTextTrainingInput(peft_trainer.TrainingInput):
   """Extended TrainingInput dataclass to carry MaxText-specific fields."""

   #: Position indices for the tokens (for RoPE).
-  positions: jax.Array = None
+  positions: jax.Array | None = None
   #: Segment IDs for packed sequences (0=padding, 1+=examples).
-  decoder_segment_ids: jax.Array = None
+  decoder_segment_ids: jax.Array | None = None
   #: Ground truth target tokens (used for loss calculation and logging).
-  targets: jax.Array = None
+  targets: jax.Array | None = None
   #: Position indices for the target tokens.
-  targets_position: jax.Array = None
+  targets_position: jax.Array | None = None
   #: Segment IDs for packed target tokens.
-  targets_segmentation: jax.Array = None
+  targets_segmentation: jax.Array | None = None


 # -----------------------------------------------------------------------------
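
The net effect of the hunk above: `logits` becomes a required field, while fields that may legitimately be absent are typed `jax.Array | None` (Python 3.10+ union syntax) instead of carrying a `None` default on a non-optional annotation. A minimal, self-contained sketch of the resulting behavior; the toy class below mirrors the pattern and is not code from the commit:

```python
import flax
import jax
import jax.numpy as jnp


@flax.struct.dataclass(frozen=True)
class ForwardOutput:
  """Toy stand-in mirroring the refactored field pattern."""

  logits: jax.Array                      # required: no default value
  activations: jax.Array | None = None   # optional: may stay None


out = ForwardOutput(logits=jnp.zeros((2, 4)))
assert out.activations is None  # omitted optional field defaults to None
# ForwardOutput()  # would raise TypeError: missing required 'logits'
```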
@@ -222,7 +222,7 @@ def compute_loss(
     # 3. Combine losses
     base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)

-    feature_loss = 0.0
+    feature_loss = jnp.array(0.0)
     if self.beta_feature > 0.0:

       if self.layer_indices is not None:
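
Replacing the Python float `0.0` with `jnp.array(0.0)` is a type-hygiene fix: the accumulated loss is a JAX array of a stable dtype no matter which trace-time branches fire, so jitted callers always see the same output type. A self-contained toy illustrating the pattern; `toy_feature_loss` below is an illustration, not the commit's `compute_loss`:

```python
import jax
import jax.numpy as jnp


def toy_feature_loss(feat: jax.Array, beta: float) -> jax.Array:
  # Seed the accumulator as a jax scalar so the return value is an
  # array even when the feature branch is skipped at trace time.
  feature_loss = jnp.array(0.0)
  if beta > 0.0:  # plain Python branch, resolved during tracing
    feature_loss = beta * jnp.mean(feat ** 2)
  return feature_loss


jitted = jax.jit(toy_feature_loss, static_argnames="beta")
print(jitted(jnp.ones(3), beta=0.0))  # 0.0 as a jax scalar, not a Python float
print(jitted(jnp.ones(3), beta=0.5))  # 0.5
```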
@@ -364,6 +364,69 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
         force=force,
     )

+  def maybe_restore(
+      self,
+      model: Any,
+      optimizer: Any = None,
+      restore_only_lora_params: bool = False,
+  ) -> tuple[int, dict[str, Any]]:
+    """Restores model and optimizer state if a checkpoint exists, using correct sharding specs."""
+    if self._checkpoint_manager is None:
+      return 0, {}
+
+    step = self._checkpoint_manager.latest_step()
+    if step is None:
+      return 0, {}
+
+    max_logging.log(f"Restoring from checkpoint step {step}...")
+
+    # Extract student model safely
+    target_model = getattr(model, "student_model", model)
+
+    if restore_only_lora_params:
+      params = nnx.state(target_model, nnx.LoRAParam)
+    else:
+      params = nnx.state(target_model)
+
+    def map_to_pspec(data):
+      if hasattr(data, "sharding"):
+        return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
+      return None
+
+    restore_args = jax.tree.map(map_to_pspec, params)
+
+    cp_restore_args = {
+        "model_params": checkpoint.args.PyTreeRestore(
+            item=params,
+            restore_args=restore_args,
+        )
+    }
+
+    if optimizer is not None:
+      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
+      opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
+      cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
+          item=optimizer_state,
+          restore_args=opt_restore_args,
+      )
+
+    restored = self._checkpoint_manager.restore(
+        step,
+        args=checkpoint.args.Composite(**cp_restore_args),
+    )
+
+    nnx.update(target_model, restored.model_params)
+    if optimizer is not None:
+      nnx.update(optimizer, restored.optimizer_state)
+
+    metadata = self._checkpoint_manager.metadata(step)
+    if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
+      custom_metadata = metadata.custom_metadata
+    else:
+      custom_metadata = {}
+
+    return step, dict(custom_metadata)
+
   def restore_iterator(self):
     """Restores the iterator using MaxText's logic."""
     if self._checkpoint_manager is None or self._iterator is None:
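
For context, a sketch of how a resume path might call the new method; `checkpointer`, `distill_model`, `optimizer`, and `total_steps` are assumed names, and only `maybe_restore`'s signature and return contract come from the diff above:

```python
# Hypothetical resume flow (not from the commit).
start_step, custom_metadata = checkpointer.maybe_restore(
    model=distill_model,   # a wrapper exposing .student_model, or a bare model
    optimizer=optimizer,   # optional: optimizer state is restored when given
)
if start_step == 0 and not custom_metadata:
  max_logging.log("No checkpoint found; training from scratch.")

for step in range(start_step, total_steps):
  ...  # one training step
```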

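The sharding-aware part of `maybe_restore` is the `map_to_pspec` helper: each live array is mapped to an `ArrayRestoreArgs` carrying its current sharding, so restored arrays come back on-device with the live layout. A minimal Orbax round-trip under assumed conditions (single host, writable `/tmp/ckpt_demo`; the real module imports Orbax as `checkpoint`):

```python
import jax
import jax.numpy as jnp
import orbax.checkpoint as ocp

mgr = ocp.CheckpointManager("/tmp/ckpt_demo")
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}

mgr.save(0, args=ocp.args.Composite(model_params=ocp.args.PyTreeSave(item=params)))
mgr.wait_until_finished()  # saves are async; block before restoring


def map_to_pspec(data):
  # Same idea as the helper in the diff: reuse the live array's sharding.
  if hasattr(data, "sharding"):
    return ocp.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
  return None


restore_args = jax.tree.map(map_to_pspec, params)
restored = mgr.restore(
    0,
    args=ocp.args.Composite(
        model_params=ocp.args.PyTreeRestore(item=params, restore_args=restore_args)
    ),
)
print(restored.model_params["w"].sharding)  # matches the live array's layout
```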