@@ -44,25 +44,25 @@ class DistillationForwardOutput:
   """Dataclass to carry MaxText-specific output fields."""

   #: logits
-  logits: jax.Array = None
+  logits: jax.Array
   #: out_projection_activations
-  out_projection_activations: jax.Array = None
+  out_projection_activations: jax.Array | None = None


 @flax.struct.dataclass(frozen=True)
 class MaxTextTrainingInput(peft_trainer.TrainingInput):
   """Extended TrainingInput dataclass to carry MaxText-specific fields."""

   #: Position indices for the tokens (for RoPE).
-  positions: jax.Array = None
+  positions: jax.Array | None = None
   #: Segment IDs for packed sequences (0=padding, 1+=examples).
-  decoder_segment_ids: jax.Array = None
+  decoder_segment_ids: jax.Array | None = None
   #: Ground truth target tokens (used for loss calculation and logging).
-  targets: jax.Array = None
+  targets: jax.Array | None = None
   #: Position indices for the target tokens.
-  targets_position: jax.Array = None
+  targets_position: jax.Array | None = None
   #: Segment IDs for packed target tokens.
-  targets_segmentation: jax.Array = None
+  targets_segmentation: jax.Array | None = None


 # -----------------------------------------------------------------------------
@@ -222,7 +222,7 @@ def compute_loss(
     # 3. Combine losses
     base_logit_loss = (self.alpha * soft_loss) + ((1.0 - self.alpha) * hard_loss)

-    feature_loss = 0.0
+    feature_loss = jnp.array(0.0)
     if self.beta_feature > 0.0:

       if self.layer_indices is not None:
@@ -364,6 +364,86 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
         force=force,
     )

+  def maybe_restore(
+      self,
+      model: Any,
+      optimizer: Any = None,
+      restore_only_lora_params: bool = False,
+  ) -> tuple[int, dict[str, Any]]:
+    """Restores model and optimizer state if a checkpoint exists, using correct sharding specs.
+
+    This method checks for the latest available checkpoint. If found, it restores the
+    model parameters and optionally the optimizer state in-place. It automatically maps
+    each parameter's `sharding` attribute to Orbax restore arguments to ensure the
+    tensors are placed on the correct device meshes.
+
+    Args:
+      model: The model to restore. If a `ModelBundle` is provided, it automatically
+        extracts and restores only the `student_model`.
+      optimizer: The optimizer state to restore. If None, optimizer restoration is skipped.
+      restore_only_lora_params: If True, restricts restoration to parameters marked
+        as `nnx.LoRAParam`.
+
+    Returns:
+      A tuple containing the restored step number (0 if no checkpoint was found)
+      and a dictionary of custom metadata.
+    """
+    if self._checkpoint_manager is None:
+      return 0, {}
+
+    step = self._checkpoint_manager.latest_step()
+    if step is None:
+      return 0, {}
+
+    max_logging.log(f"Restoring from checkpoint step {step}...")
+
+    # Extract student model safely
+    target_model = getattr(model, "student_model", model)
+
+    if restore_only_lora_params:
+      params = nnx.state(target_model, nnx.LoRAParam)
+    else:
+      params = nnx.state(target_model)
+
+    def map_to_pspec(data):
+      if hasattr(data, "sharding"):
+        return checkpoint.type_handlers.ArrayRestoreArgs(sharding=data.sharding)
+      return None
+
+    restore_args = jax.tree.map(map_to_pspec, params)
+
+    cp_restore_args = {
+        "model_params": checkpoint.args.PyTreeRestore(
+            item=params,
+            restore_args=restore_args,
+        )
+    }
+
+    if optimizer is not None:
+      optimizer_state = nnx.state(optimizer, nnx.optimizer.OptState)
+      opt_restore_args = jax.tree.map(map_to_pspec, optimizer_state)
+      cp_restore_args["optimizer_state"] = checkpoint.args.PyTreeRestore(
+          item=optimizer_state,
+          restore_args=opt_restore_args,
+      )
+
+    restored = self._checkpoint_manager.restore(
+        step,
+        args=checkpoint.args.Composite(**cp_restore_args),
+    )
+
+    nnx.update(target_model, restored.model_params)
+    if optimizer is not None:
+      nnx.update(optimizer, restored.optimizer_state)
+
+    metadata = self._checkpoint_manager.metadata(step)
+    if metadata and hasattr(metadata, "custom_metadata") and metadata.custom_metadata is not None:
+      custom_metadata = metadata.custom_metadata
+    else:
+      custom_metadata = {}
+
+    return step, dict(custom_metadata)
+
   def restore_iterator(self):
     """Restores the iterator using MaxText's logic."""
     if self._checkpoint_manager is None or self._iterator is None:
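
Illustrative usage of the new `maybe_restore` method, as a sketch only and not part of this change: the `checkpointer`, `model`, `optimizer`, and `total_steps` names below are hypothetical stand-ins for whatever the training script actually constructs.

# Resume from the latest checkpoint if one exists; otherwise start at step 0.
start_step, custom_metadata = checkpointer.maybe_restore(
    model=model,
    optimizer=optimizer,
    restore_only_lora_params=False,
)

for step in range(start_step, total_steps):
  ...  # regular training loop, continuing from the restored step

Because the method returns step 0 and an empty metadata dict when no checkpoint is found, the same call works for both cold starts and resumption.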