@@ -303,6 +303,7 @@ def layer_fn(carry, scanned_vars):
303303 layer = nnx .merge (graphdef , current_params , current_state )
304304 layer_out = layer (carry , decoder_segment_ids , decoder_positions , deterministic , model_mode , ** kwargs )
305305 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
306+ nnx .pop (layer , nnx .Intermediate )
306307 return new_carry , nnx .state (layer )
307308
308309 final_carry , scanned_state = jax .lax .scan (layer_fn , inputs , (params , state ))
@@ -534,6 +535,8 @@ def _create_scanned_layers(
534535 self , decoder_layer_class , length : int , metadata_axis_name : str , rngs : nnx .Rngs , ** layer_kwargs
535536 ):
536537 """Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
538+ if length == 0 :
539+ return nnx .List ([])
537540
538541 def create_layer_fn (rng ):
539542 return decoder_layer_class (
@@ -566,13 +569,17 @@ def pure_layer_fn(state_in, y_in):
566569 out = merged_layer (y_in , ** kwargs )
567570 return out , nnx .state (merged_layer )
568571
569- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
570- out , new_state = checkpointed_fn (state , y )
572+ if not self ._uses_linen_fp8_ops ():
573+ pure_layer_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
574+ out , new_state = pure_layer_fn (state , y )
571575 nnx .update (layer , new_state )
572576 return out
573577
574578 def _apply_layers_sequentially (self , layers , x_in , * args , length : int , ** kwargs ):
575579 """Runs the layer stack using nnx.scan."""
580+ if length == 0 :
581+ _ , empty_state = nnx .split (layers )
582+ return x_in , empty_state
576583 policy = self .get_remat_policy ()
577584 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
578585 graphdef , params , state = nnx .split (
@@ -608,7 +615,25 @@ def layer_fn(carry, scanned_vars):
608615 # Run the layer (Filter kwargs if using the solution from previous turn)
609616 layer_out = layer (carry , * args , ** valid_kwargs )
610617 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
611- return new_carry , nnx .state (layer )
618+ nnx .pop (layer , nnx .Intermediate )
619+ new_current_state = nnx .state (layer )
620+ return new_carry , new_current_state
621+
622+ if self ._uses_linen_fp8_ops ():
623+ # jax.lax.scan is incompatible with Linen fp8 ops: put_variable in setup() stores
624+ # scan-level tracers as Python attributes on the Linen module, causing a tracer leak
625+ # across the scan boundary. Fall back to a Python loop instead.
626+ x = x_in
627+ for i in range (length ):
628+ params_i = jax .tree .map (lambda p , _i = i : p [_i ], params )
629+ state_i = jax .tree .map (lambda s , _i = i : s [_i ], state )
630+ layer = nnx .merge (graphdef , params_i , state_i )
631+ layer_out = layer (x , * args , ** valid_kwargs )
632+ x = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
633+ nnx .pop (layer , nnx .Intermediate )
634+ if scan_axis != 0 :
635+ params = jax .tree .map (lambda p : jnp .moveaxis (p , 0 , scan_axis ), params )
636+ return x , nnx .State .merge (params , state )
612637
613638 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
614639 final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
@@ -672,7 +697,8 @@ def get_chunk(pytree, start, end):
672697 layer_out = layer (y , * layer_args , ** valid_kwargs )
673698 y = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
674699
675- _ , new_eng_mutables = nnx .split (layer , nnx .Param , ...)
700+ nnx .pop (layer , nnx .Intermediate )
701+ _ , _ , new_eng_mutables = nnx .split (layer , nnx .Param , ...)
676702 new_eng_mutables = jax .tree .map (lambda x : jnp .expand_dims (x , axis = 0 ), new_eng_mutables )
677703 updated_mutables_chunks .append (new_eng_mutables )
678704 current_idx += 1
@@ -698,10 +724,12 @@ def layer_fn(carry, scanned_vars):
698724 l = nnx .merge (graphdef , curr_p , curr_m )
699725 l_out = l (carry , * layer_args , ** valid_kwargs )
700726 n_carry = l_out [0 ] if isinstance (l_out , tuple ) else l_out
701- _ , n_mut = nnx .split (l , nnx .Param , ...)
727+ nnx .pop (l , nnx .Intermediate )
728+ _ , _ , n_mut = nnx .split (l , nnx .Param , ...)
702729 return n_carry , n_mut
703730
704- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
731+ if not self ._uses_linen_fp8_ops ():
732+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
705733 y , new_chunk_mutables = jax .lax .scan (layer_fn , y , (chunk_params , chunk_mutables ))
706734 updated_mutables_chunks .append (new_chunk_mutables )
707735 current_idx = next_boundary
@@ -742,7 +770,11 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in, dynamic_kwargs):
742770 out_y , out_kv = merged_layer (y_in , * layer_args , kv_cache = kv_in , ** dynamic_kwargs )
743771 return out_y , out_kv , nnx .state (merged_layer )
744772
745- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
773+ checkpointed_fn = (
774+ pure_layer_fn
775+ if self ._uses_linen_fp8_ops ()
776+ else jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
777+ )
746778
747779 for lyr in range (num_layers ):
748780 attr_name = f"{ base_name } _{ lyr } "
@@ -921,6 +953,10 @@ def get_remat_policy(self):
921953 assert cfg .remat_policy == "full" , "Remat policy needs to be on list of remat policies"
922954 return policy
923955
956+ def _uses_linen_fp8_ops (self ) -> bool :
957+ """Returns True if config.quantization is one of the Linen fp8 modes (fp8_gpu, fp8_nanoo); callers use this to skip jax.checkpoint and, where a Python-loop fallback exists, jax.lax.scan."""
958+ return self .config .quantization in ("fp8_gpu" , "fp8_nanoo" )
959+
924960 def get_norm_layer (self , num_features : int , rngs : nnx .Rngs ):
925961 """Helper to retrieve the correct normalization layer class based on config, partially applied with common arguments."""
926962 if self .config .decoder_block in (
@@ -1072,10 +1108,18 @@ def __call__(
10721108 audio_embeddings : None | jnp .ndarray = None ,
10731109 audio_masks : None | jnp .ndarray = None ,
10741110 deepstack_visual_embeds : None | list [jnp .ndarray ] = None ,
1111+ multimodal_input = None ,
10751112 ):
10761113 cfg = self .config
10771114 assert decoder_input_tokens .ndim == 2 # [batch, len]
10781115
1116+ if multimodal_input is not None :
1117+ image_embeddings = multimodal_input .image_embeddings
1118+ bidirectional_mask = multimodal_input .bidirectional_mask
1119+ image_masks = multimodal_input .image_masks
1120+ audio_embeddings = multimodal_input .audio_embeddings
1121+ audio_masks = multimodal_input .audio_masks
1122+
10791123 # [batch, length] -> [batch, length, emb_dim]
10801124 y = self ._apply_embedding (
10811125 shared_embedding ,
@@ -1119,12 +1163,20 @@ def __call__(
11191163 if cfg .scan_layers :
11201164 if cfg .engram_layers :
11211165 y , self .dense_layers = self ._apply_interleaved_scanned_layers (
1122- self .dense_layers , y , layer_args , layer_kwargs ,
1123- start_idx = 0 , end_idx = cfg .first_num_dense_layers ,
1124- engram_indices = cfg .engram_layers , decoder_input_tokens = decoder_input_tokens ,
1166+ self .dense_layers ,
1167+ y ,
1168+ layer_args ,
1169+ layer_kwargs ,
1170+ start_idx = 0 ,
1171+ end_idx = cfg .first_num_dense_layers ,
1172+ engram_indices = cfg .engram_layers ,
1173+ decoder_input_tokens = decoder_input_tokens ,
11251174 )
11261175 y , self .moe_layer = self ._apply_interleaved_scanned_layers (
1127- self .moe_layer , y , layer_args , layer_kwargs ,
1176+ self .moe_layer ,
1177+ y ,
1178+ layer_args ,
1179+ layer_kwargs ,
11281180 start_idx = 0 ,
11291181 end_idx = (cfg .num_decoder_layers - cfg .first_num_dense_layers ),
11301182 engram_indices = [e - cfg .first_num_dense_layers for e in cfg .engram_layers ],
@@ -1141,7 +1193,12 @@ def __call__(
11411193 if cfg .use_batch_split_schedule :
11421194 mock_params = self ._build_linen_params (self .moe_layer )
11431195 y = deepseek_batchsplit .scan_batch_split_layers (
1144- y , mock_params , decoder_positions , mesh = self .mesh , cfg = cfg , num_layers = num_moe ,
1196+ y ,
1197+ mock_params ,
1198+ decoder_positions ,
1199+ mesh = self .mesh ,
1200+ cfg = cfg ,
1201+ num_layers = num_moe ,
11451202 )
11461203 elif hasattr (self , "moe_layers_outside_pipeline" ):
11471204 num_moe_outside = (cfg .num_decoder_layers - cfg .first_num_dense_layers ) - cfg .pipeline_parallel_layers
@@ -1223,7 +1280,6 @@ def __call__(
12231280 decoder_input_tokens = decoder_input_tokens ,
12241281 )
12251282
1226-
12271283 else :
12281284 # Non-Pipeline Run
12291285 if cfg .scan_layers :
@@ -1265,12 +1321,9 @@ def __call__(
12651321 y ,
12661322 raw_weights ,
12671323 decoder_positions ,
1268- decoder_segment_ids ,
1269- model_mode = model_mode ,
12701324 mesh = self .mesh ,
1271- quant = self .quant ,
12721325 cfg = cfg ,
1273- policy = self . get_remat_policy () ,
1326+ num_layers = num_moe ,
12741327 )
12751328 else :
12761329 y , new_state = self ._apply_layers_sequentially (
0 commit comments