@@ -463,13 +463,61 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, kv_caches
463463
464464 use_kv = kv_caches_stacked is not None
465465
def stash_origin_metadata(x):
  """Snapshot a variable's shape and sharding metadata under ``origin_*`` keys.

  Objects that lack either ``get_metadata`` or ``replace`` are returned
  unchanged. For anything else the result of ``x.replace(...)`` is returned
  with ``origin_shape`` set to ``x.value.shape`` and, for each of
  ``sharding``, ``out_sharding`` and ``sharding_names`` found in the current
  metadata, an ``origin_<key>`` entry holding that key's current value.

  NOTE(review): this only acts on objects exposing the variable-like
  interface. If it is applied through ``jax.tree.map`` without an ``is_leaf``
  predicate, the tree traversal may hand it raw array leaves instead, making
  it a no-op -- confirm at the call site.
  """
  if not (hasattr(x, "get_metadata") and hasattr(x, "replace")):
    return x

  current = x.get_metadata()
  stashed = {"origin_shape": x.value.shape}
  stashed.update(
      (f"origin_{name}", current[name])
      for name in ("sharding", "out_sharding", "sharding_names")
      if name in current
  )
  return x.replace(**stashed)
476+
477+ params = jax .tree .map (stash_origin_metadata , params )
478+ state = jax .tree .map (stash_origin_metadata , state )
479+
466480 def layer_fn (carry , scanned_vars ):
467481 if use_kv :
468482 current_params , current_state , kv_cache_layer = scanned_vars
469483 else :
470484 current_params , current_state = scanned_vars
471485 kv_cache_layer = None
472486
def rank_consistent_spec(spec, shape):
  """Build a PartitionSpec with exactly ``len(shape)`` entries from ``spec``.

  Args:
    spec: an iterable of partition entries (e.g. a PartitionSpec), or None.
    shape: a sized object; only its length is used.

  Returns:
    None when ``spec`` is None, otherwise a ``jax.sharding.PartitionSpec``.
    When ``spec`` is too long, the first occurrence of ``"layers"`` and then
    of ``"stage"`` is removed (stopping as soon as the length matches); any
    remaining surplus is dropped from the front. When ``spec`` is too short
    it is left-padded with ``None``.
  """
  if spec is None:
    return None

  target_rank = len(shape)
  entries = list(spec)

  if len(entries) > target_rank:
    for removable in ("layers", "stage"):
      if removable not in entries:
        continue
      entries.remove(removable)
      if len(entries) == target_rank:
        break

  surplus = len(entries) - target_rank
  if surplus > 0:
    # Still too long after removing the named axes: discard leading entries.
    del entries[:surplus]
  elif surplus < 0:
    # Too short: pad on the left with unsharded (None) entries.
    entries = [None] * (-surplus) + entries

  return jax.sharding.PartitionSpec(*entries)
502+
def fix_node_rank(x):
  """Re-rank spec-like metadata on a variable whose value rank differs.

  Objects lacking ``get_metadata``, ``replace`` or ``value`` are returned
  unchanged. Otherwise every metadata entry that is a PartitionSpec, tuple or
  list and whose length differs from ``x.value.ndim`` is replaced by the
  result of ``rank_consistent_spec(entry, x.value.shape)``. A PartitionSpec
  entry stays a PartitionSpec; a tuple or list entry comes back as a tuple.

  Entries whose key starts with ``origin_`` are never rewritten: they are
  snapshots (``stash_origin_metadata`` stores e.g. ``origin_shape``, a tuple
  of ints) rather than live sharding specs, and re-ranking them would destroy
  the information they were stashed to preserve.

  Returns:
    ``x.replace(**updates)`` when at least one entry changed, else ``x``.
  """
  if not (hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value")):
    return x

  spec_cls = jax.sharding.PartitionSpec
  updates = {}
  for key, axes in x.get_metadata().items():
    if isinstance(key, str) and key.startswith("origin_"):
      continue

    # Check for PartitionSpec before tuple/list: on JAX releases where
    # PartitionSpec subclasses tuple, the tuple test would also match and the
    # spec would be handed back as a plain tuple.
    if isinstance(axes, spec_cls):
      keep_spec_type = True
      spec_obj = axes
    elif isinstance(axes, (tuple, list)):
      keep_spec_type = False
      spec_obj = spec_cls(*axes)
    else:
      continue

    if len(spec_obj) == x.value.ndim:
      continue

    new_spec = rank_consistent_spec(spec_obj, x.value.shape)
    updates[key] = new_spec if keep_spec_type else tuple(new_spec)

  return x.replace(**updates) if updates else x
516+
517+ is_nnx_var = lambda x : hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
518+ current_params = jax .tree .map (fix_node_rank , current_params , is_leaf = is_nnx_var )
519+ current_state = jax .tree .map (fix_node_rank , current_state , is_leaf = is_nnx_var )
520+
473521 if self .config .parameter_memory_host_offload :
474522 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
475523
@@ -540,8 +588,43 @@ def _ensure_scan_leading_axis(x):
540588 if scan_axis != 0 :
541589 scanned_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), scanned_params )
542590
def restore_origin_metadata(x):
  """Put full-rank sharding metadata back on a variable.

  Objects lacking ``get_metadata`` or ``replace`` are returned unchanged.
  For each of ``sharding``, ``out_sharding`` and ``sharding_names``:

  * if the metadata holds ``origin_<key>``, that value is copied to ``<key>``;
  * otherwise, if ``<key>`` is a PartitionSpec, tuple or list that does not
    contain ``"layers"``, ``"layers"`` is inserted at index
    ``min(self.config.param_scan_axis, len(spec))``. A PartitionSpec stays a
    PartitionSpec; a tuple or list comes back as a tuple.

  ``self`` is read from the enclosing scope and is only touched in the
  fallback branch.

  Returns:
    ``x.replace(**updates)`` when anything changed, else ``x``.
  """
  if not (hasattr(x, "get_metadata") and hasattr(x, "replace")):
    return x

  spec_cls = jax.sharding.PartitionSpec
  metadata = x.get_metadata()
  updates = {}
  for key in ("sharding", "out_sharding", "sharding_names"):
    origin_key = f"origin_{key}"
    if origin_key in metadata:
      updates[key] = metadata[origin_key]
      continue

    axes = metadata.get(key)
    # Check for PartitionSpec before tuple/list: on JAX releases where
    # PartitionSpec subclasses tuple, the tuple test would also match and the
    # spec would be handed back as a plain tuple.
    if isinstance(axes, spec_cls):
      keep_spec_type = True
    elif isinstance(axes, (tuple, list)):
      keep_spec_type = False
    else:
      continue

    entries = list(axes)
    if "layers" in entries:
      continue
    insert_at = min(self.config.param_scan_axis, len(entries))
    entries.insert(insert_at, "layers")
    new_spec = spec_cls(*entries)
    updates[key] = new_spec if keep_spec_type else tuple(new_spec)

  return x.replace(**updates) if updates else x
612+
613+ is_leaf_with_metadata = lambda x : hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
614+ scanned_params = jax .tree .map (restore_origin_metadata , scanned_params , is_leaf = is_leaf_with_metadata )
615+ scanned_other = jax .tree .map (restore_origin_metadata , scanned_other , is_leaf = is_leaf_with_metadata )
616+
543617 if dynamic_graph_init :
544618 out_layers = nnx .merge (updated_graphdef [0 ], scanned_params , scanned_other )
619+
620+ for attr_name , attr_val in self .__dict__ .items ():
621+ if attr_val is layers :
622+ setattr (self , attr_name , out_layers )
623+ break
624+
625+ g , s = nnx .split (self )
626+ new_self = nnx .merge (g , s )
627+ nnx .update (self , nnx .state (new_self ))
545628 else :
546629 nnx .update (layers , nnx .State .merge (scanned_params , scanned_other ))
547630 out_layers = layers
0 commit comments