@@ -44,25 +44,25 @@ class DistillationForwardOutput:
4444 """Dataclass to carry MaxText-specific output fields."""
4545
4646 #: logits
47- logits : jax .Array = None
47+ logits : jax .Array
4848 #: out_projection_activations
49- out_projection_activations : jax .Array = None
49+ out_projection_activations : jax .Array | None = None
5050
5151
@flax.struct.dataclass(frozen=True)
class MaxTextTrainingInput(peft_trainer.TrainingInput):
  """TrainingInput extended with the extra batch fields MaxText models consume.

  Every field defaults to None so callers that do not use sequence packing or
  explicit targets can construct the input without them.
  """

  #: Position indices for the tokens (for RoPE).
  positions: jax.Array | None = None
  #: Segment IDs for packed sequences (0=padding, 1+=examples).
  decoder_segment_ids: jax.Array | None = None
  #: Ground truth target tokens (used for loss calculation and logging).
  targets: jax.Array | None = None
  #: Position indices for the target tokens.
  targets_position: jax.Array | None = None
  #: Segment IDs for packed target tokens.
  targets_segmentation: jax.Array | None = None
6767
6868# -----------------------------------------------------------------------------
@@ -222,7 +222,7 @@ def compute_loss(
222222 # 3. Combine losses
223223 base_logit_loss = (self .alpha * soft_loss ) + ((1.0 - self .alpha ) * hard_loss )
224224
225- feature_loss = 0.0
225+ feature_loss = jnp . array ( 0.0 )
226226 if self .beta_feature > 0.0 :
227227
228228 if self .layer_indices is not None :
@@ -364,6 +364,69 @@ def save(self, step, model, optimizer=None, save_only_lora_params=False, force=F
364364 force = force ,
365365 )
366366
367+ def maybe_restore (
368+ self ,
369+ model : Any ,
370+ optimizer : Any = None ,
371+ restore_only_lora_params : bool = False ,
372+ ) -> tuple [int , dict [str , Any ]]:
373+ """Restores model and optimizer state if a checkpoint exists, using correct sharding specs."""
374+ if self ._checkpoint_manager is None :
375+ return 0 , {}
376+
377+ step = self ._checkpoint_manager .latest_step ()
378+ if step is None :
379+ return 0 , {}
380+
381+ max_logging .log (f"Restoring from checkpoint step { step } ..." )
382+
383+ # Extract student model safely
384+ target_model = getattr (model , "student_model" , model )
385+
386+ if restore_only_lora_params :
387+ params = nnx .state (target_model , nnx .LoRAParam )
388+ else :
389+ params = nnx .state (target_model )
390+
391+ def map_to_pspec (data ):
392+ if hasattr (data , "sharding" ):
393+ return checkpoint .type_handlers .ArrayRestoreArgs (sharding = data .sharding )
394+ return None
395+
396+ restore_args = jax .tree .map (map_to_pspec , params )
397+
398+ cp_restore_args = {
399+ "model_params" : checkpoint .args .PyTreeRestore (
400+ item = params ,
401+ restore_args = restore_args ,
402+ )
403+ }
404+
405+ if optimizer is not None :
406+ optimizer_state = nnx .state (optimizer , nnx .optimizer .OptState )
407+ opt_restore_args = jax .tree .map (map_to_pspec , optimizer_state )
408+ cp_restore_args ["optimizer_state" ] = checkpoint .args .PyTreeRestore (
409+ item = optimizer_state ,
410+ restore_args = opt_restore_args ,
411+ )
412+
413+ restored = self ._checkpoint_manager .restore (
414+ step ,
415+ args = checkpoint .args .Composite (** cp_restore_args ),
416+ )
417+
418+ nnx .update (target_model , restored .model_params )
419+ if optimizer is not None :
420+ nnx .update (optimizer , restored .optimizer_state )
421+
422+ metadata = self ._checkpoint_manager .metadata (step )
423+ if metadata and hasattr (metadata , "custom_metadata" ) and metadata .custom_metadata is not None :
424+ custom_metadata = metadata .custom_metadata
425+ else :
426+ custom_metadata = {}
427+
428+ return step , dict (custom_metadata )
429+
367430 def restore_iterator (self ):
368431 """Restores the iterator using MaxText's logic."""
369432 if self ._checkpoint_manager is None or self ._iterator is None :
0 commit comments