@@ -449,13 +449,57 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
def layer_fn(carry, scanned_vars):
    # One scan step: rebuild the layer from this step's (params, state) slice,
    # run it on the carry, then re-split so updated variables flow out of scan.
    current_params, current_state = scanned_vars
def rank_consistent_spec(spec, shape):
    """Return a PartitionSpec whose length matches ``len(shape)``.

    Over-ranked specs first drop the scan-introduced axis names
    ("layers"/"stage"); any remaining excess is stripped from the left
    (standard JAX rank reduction). Under-ranked specs are left-padded
    with ``None`` (unsharded). ``None`` specs pass through unchanged.
    """
    if spec is None:
        return None
    names = list(spec)

    # Prefer removing scan axes when the spec has too many entries.
    if len(names) > len(shape):
        for scan_name in ("layers", "stage"):
            if scan_name in names:
                names.remove(scan_name)
                if len(names) == len(shape):
                    break

    # Strip any remaining excess entries from the left.
    del names[: max(0, len(names) - len(shape))]

    # Left-pad with None when the spec is still too short.
    names[:0] = [None] * (len(shape) - len(names))

    return jax.sharding.PartitionSpec(*names)
472+
def fix_node_rank(x):
    """Rank-normalize the sharding metadata of an NNX-variable-like leaf.

    For any metadata entry that is a PartitionSpec/tuple/list whose length
    differs from the wrapped value's ndim, rewrite it with
    ``rank_consistent_spec`` while preserving the original container type.
    Leaves without the variable protocol pass through untouched.
    """
    looks_like_var = (
        hasattr(x, "get_metadata") and hasattr(x, "replace") and hasattr(x, "value")
    )
    if not looks_like_var:
        return x

    patched = {}
    for key, axes in x.get_metadata().items():
        if not isinstance(axes, (jax.sharding.PartitionSpec, tuple, list)):
            continue
        # View tuple/list metadata as a spec for the rank comparison.
        spec = jax.sharding.PartitionSpec(*axes) if isinstance(axes, (tuple, list)) else axes
        if len(spec) == x.value.ndim:
            continue
        fixed = rank_consistent_spec(spec, x.value.shape)
        # Keep the caller's container type (tuple/list stays a tuple).
        patched[key] = tuple(fixed) if isinstance(axes, (tuple, list)) else fixed

    return x.replace(**patched) if patched else x
489+
# Normalize metadata ranks before merging so sharding specs agree with the
# per-step (rank-reduced) value shapes seen inside the scan body.
is_nnx_var = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
current_params = jax.tree.map(fix_node_rank, current_params, is_leaf=is_nnx_var)
current_state = jax.tree.map(fix_node_rank, current_state, is_leaf=is_nnx_var)

if self.config.parameter_memory_host_offload:
    # Bring host-offloaded parameters back into device memory before use.
    current_params = jax.tree.map(lambda x: jax.device_put(x, max_utils.device_space()), current_params)

# Reassemble a live layer module from graphdef + this step's variables.
layer = nnx.merge(graphdef, current_params, current_state)

layer_out = layer(carry, *args, **valid_kwargs)
# Layers may return either `carry` or a tuple whose first element is the carry.
new_carry = layer_out[0] if isinstance(layer_out, tuple) else layer_out

# Extract EVERYTHING to capture new parameters
new_graphdef, updated_params, updated_state = nnx.split(layer, nnx.Param, ...)

if dynamic_graph_init:
@@ -466,23 +510,154 @@ def layer_fn(carry, scanned_vars):
466510
467511 return new_carry , (returned_params , updated_state )
468512
469- layer_fn_wrapped = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
if dynamic_graph_init:
    # Eager (unrolled) path: run each layer once in Python so lazily created
    # parameters materialize; outputs are stacked below to mimic scan's layout.
    # NOTE(review): debug print should become a logger call before landing.
    print(f"[DEBUG] Starting Dynamic Graph Init Loop (length={length})")
    curr_carry = x_in
    out_params_list = []  # per-layer Param states, stacked after the loop
    out_other_list = []   # per-layer non-Param states, stacked after the loop
518+
519+ def _slice_and_unpromote (x , i ):
520+ # Resolve physical value and shape
521+ is_var = hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
522+ val = x .value if is_var else x
523+
524+ if not hasattr (val , "shape" ) or len (val .shape ) == 0 or val .shape [0 ] != length :
525+ return x
526+
527+ # 1. Slice value
528+ sliced_val = val [i ]
529+
530+ # 2. Slice logical metadata if it's an NNX variable
531+ if is_var :
532+ metadata = x .get_metadata ()
533+ updates = {}
534+ for sharding_key in ["sharding" , "out_sharding" , "sharding_names" ]:
535+ axes = metadata .get (sharding_key )
536+ if isinstance (axes , jax .sharding .PartitionSpec ):
537+ spec_list = list (axes )
538+
539+ # Aggressively reduce rank to match sliced_val.ndim
540+ for axis_to_remove in ["layers" , "stage" ]:
541+ if axis_to_remove in spec_list and len (spec_list ) > sliced_val .ndim :
542+ spec_list .remove (axis_to_remove )
543+
544+ while len (spec_list ) > sliced_val .ndim :
545+ spec_list .pop (0 )
546+
547+ while len (spec_list ) < sliced_val .ndim :
548+ spec_list .insert (0 , None )
549+
550+ new_spec = jax .sharding .PartitionSpec (* spec_list )
551+ updates [sharding_key ] = new_spec
552+
553+ return x .replace (value = sliced_val , ** updates )
554+
555+ return sliced_val
556+
def _promote_to_scanned(x):
    """Adds 'layers' axis back to newly created parameters if scanning is enabled."""
    if not self.config.scan_layers:
        return x

    if not (hasattr(x, "get_metadata") and hasattr(x, "replace")):
        return x

    metadata = x.get_metadata()
    # Where to re-insert the scan axis comes from config.
    scan_axis = self.config.param_scan_axis

    updates = {}
    for key in ("sharding", "out_sharding", "sharding_names"):
        axes = metadata.get(key)
        if isinstance(axes, jax.sharding.PartitionSpec):
            names = list(axes)
            if "layers" not in names:
                # Cap the position so it never exceeds the spec length.
                names.insert(min(scan_axis, len(names)), "layers")
                updates[key] = jax.sharding.PartitionSpec(*names)

    return x.replace(**updates) if updates else x
583+
for i in range(length):
    # Slice both values AND logical metadata!
    is_nnx_leaf = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
    curr_params = jax.tree.map(lambda x: _slice_and_unpromote(x, i), params, is_leaf=is_nnx_leaf)
    curr_state = jax.tree.map(lambda x: _slice_and_unpromote(x, i), state, is_leaf=is_nnx_leaf)

    curr_carry, (out_p, out_o) = layer_fn(curr_carry, (curr_params, curr_state))

    # Promote ALL parameters back to rank-3 metadata immediately
    # This ensures they are ready to be stacked correctly.
    out_p = jax.tree.map(_promote_to_scanned, out_p, is_leaf=is_nnx_leaf)
    out_o = jax.tree.map(_promote_to_scanned, out_o, is_leaf=is_nnx_leaf)

    out_params_list.append(out_p)
    out_other_list.append(out_o)

final_carry = curr_carry
# Stack per-layer outputs along a new leading axis to match lax.scan's layout.
# NOTE(review): `lambda *args` shadows the enclosing function's *args, and these
# two maps pass no is_leaf (unlike the slicing maps above) — confirm NNX
# variables stack as intended here.
scanned_params = jax.tree.map(lambda *args: jnp.stack(args), *out_params_list)
scanned_other = jax.tree.map(lambda *args: jnp.stack(args), *out_other_list)
else:
    # Compiled path: wrap the step in jax.checkpoint (rematerialization under
    # the given policy) and drive it with lax.scan over the stacked variables.
    layer_fn_wrapped = jax.checkpoint(layer_fn, policy=policy, prevent_cse=prevent_cse)
471- def _ensure_scan_leading_axis (x ):
472- if not hasattr (x , "shape" ) or len (x .shape ) == 0 :
473- return jnp .broadcast_to (x , (length ,))
474- return x
608+ def _ensure_scan_leading_axis (x ):
609+ if not hasattr (x , "shape" ) or len (x .shape ) == 0 :
610+ return jnp .broadcast_to (x , (length ,))
611+ return x
475612
# Scalars / rank-0 leaves cannot be sliced by scan: broadcast to (length,).
params = jax.tree.map(_ensure_scan_leading_axis, params)
state = jax.tree.map(_ensure_scan_leading_axis, state)

# Scan over the leading (layer) axis of the stacked params/state.
final_carry, (scanned_params, scanned_other) = jax.lax.scan(layer_fn_wrapped, x_in, (params, state))
480617
if scan_axis != 0:
    # Scan stacks along axis 0; move it to the configured parameter scan axis.
    # NOTE(review): only scanned_params is moved — confirm scanned_other is
    # intentionally left with its leading scan axis.
    scanned_params = jax.tree.map(lambda x: jnp.moveaxis(x, 0, scan_axis), scanned_params)

# Re-read the scan axis from config for the metadata promotion pass below.
scan_axis = self.config.param_scan_axis
623+ def _force_promote (x ):
624+ is_nnx_leaf = hasattr (x , "get_metadata" ) and hasattr (x , "replace" )
625+ if is_nnx_leaf :
626+ metadata = x .get_metadata ()
627+ updates = {}
628+ val_ndim = x .value .ndim
629+ for sharding_key in ["sharding" , "out_sharding" , "sharding_names" ]:
630+ axes = metadata .get (sharding_key )
631+ if isinstance (axes , (jax .sharding .PartitionSpec , tuple , list )):
632+ l = list (axes )
633+ if len (l ) < val_ndim and "layers" not in l :
634+ pos = min (scan_axis , len (l ))
635+ l .insert (pos , "layers" )
636+ updates [sharding_key ] = jax .sharding .PartitionSpec (* l ) if isinstance (axes , jax .sharding .PartitionSpec ) else tuple (l )
637+ if updates :
638+ return x .replace (** updates )
639+ return x
640+
# Apply the metadata promotion to every variable leaf of the stacked trees.
is_leaf_with_metadata = lambda x: hasattr(x, "get_metadata") and hasattr(x, "replace")
scanned_params = jax.tree.map(_force_promote, scanned_params, is_leaf=is_leaf_with_metadata)
scanned_other = jax.tree.map(_force_promote, scanned_other, is_leaf=is_leaf_with_metadata)
644+
if dynamic_graph_init:
    # Perform a structural update: merge the new structure with the stacked arrays.
    # NOTE(review): updated_graphdef is produced by lines outside this chunk —
    # confirm updated_graphdef[0] is the per-layer graphdef, not the stacked one.
    out_layers = nnx.merge(updated_graphdef[0], scanned_params, scanned_other)

    # We must update the PARENT (self) to point to the new structure.
    # Safe to setattr during iteration only because of the immediate break;
    # NOTE(review): assumes `layers` is stored directly in self.__dict__ (may
    # not hold for all NNX module implementations) — verify.
    for attr_name, attr_val in self.__dict__.items():
        if attr_val is layers:
            setattr(self, attr_name, out_layers)
            print(f"[DEBUG] Materialization complete: updated self.{attr_name}")
            break

    # FORCE NNX to recognize new structural changes by splitting/merging the PARENT.
    # This updates the underlying GraphDef for the entire Decoder.
    g, s = nnx.split(self)
    new_self = nnx.merge(g, s)
    nnx.update(self, nnx.state(new_self))
else:
    # No structural change: write the scanned variables back in place.
    nnx.update(layers, nnx.State.merge(scanned_params, scanned_other))
    out_layers = layers
0 commit comments