@@ -303,6 +303,7 @@ def layer_fn(carry, scanned_vars):
303303 layer = nnx .merge (graphdef , current_params , current_state )
304304 layer_out = layer (carry , decoder_segment_ids , decoder_positions , deterministic , model_mode , ** kwargs )
305305 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
306+ nnx .pop (layer , nnx .Intermediate )
306307 return new_carry , nnx .state (layer )
307308
308309 final_carry , scanned_state = jax .lax .scan (layer_fn , inputs , (params , state ))
@@ -534,6 +535,8 @@ def _create_scanned_layers(
534535 self , decoder_layer_class , length : int , metadata_axis_name : str , rngs : nnx .Rngs , ** layer_kwargs
535536 ):
536537 """Creates a VMapped stack of layers, forcing parameter init for Compact modules."""
538+ if length == 0 :
539+ return nnx .List ([])
537540
538541 def create_layer_fn (rng ):
539542 return decoder_layer_class (
@@ -566,13 +569,17 @@ def pure_layer_fn(state_in, y_in):
566569 out = merged_layer (y_in , ** kwargs )
567570 return out , nnx .state (merged_layer )
568571
569- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
570- out , new_state = checkpointed_fn (state , y )
572+ if not self ._uses_linen_fp8_ops ():
573+ pure_layer_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
574+ out , new_state = pure_layer_fn (state , y )
571575 nnx .update (layer , new_state )
572576 return out
573577
574578 def _apply_layers_sequentially (self , layers , x_in , * args , length : int , ** kwargs ):
575579 """Runs the layer stack using nnx.scan."""
580+ if length == 0 :
581+ _ , empty_state = nnx .split (layers )
582+ return x_in , empty_state
576583 policy = self .get_remat_policy ()
577584 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
578585 graphdef , params , state = nnx .split (
@@ -608,7 +615,25 @@ def layer_fn(carry, scanned_vars):
608615 # Run the layer (Filter kwargs if using the solution from previous turn)
609616 layer_out = layer (carry , * args , ** valid_kwargs )
610617 new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
611- return new_carry , nnx .state (layer )
618+ nnx .pop (layer , nnx .Intermediate )
619+ new_current_state = nnx .state (layer )
620+ return new_carry , new_current_state
621+
622+ if self ._uses_linen_fp8_ops ():
623+ # jax.lax.scan is incompatible with Linen fp8 ops: put_variable in setup() stores
624+ # scan-level tracers as Python attributes on the Linen module, causing a tracer leak
625+ # across the scan boundary. Fall back to a Python loop instead.
626+ x = x_in
627+ for i in range (length ):
628+ params_i = jax .tree .map (lambda p , _i = i : p [_i ], params )
629+ state_i = jax .tree .map (lambda s , _i = i : s [_i ], state )
630+ layer = nnx .merge (graphdef , params_i , state_i )
631+ layer_out = layer (x , * args , ** valid_kwargs )
632+ x = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
633+ nnx .pop (layer , nnx .Intermediate )
634+ if scan_axis != 0 :
635+ params = jax .tree .map (lambda p : jnp .moveaxis (p , 0 , scan_axis ), params )
636+ return x , nnx .State .merge (params , state )
612637
613638 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
614639 final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
@@ -672,7 +697,8 @@ def get_chunk(pytree, start, end):
672697 layer_out = layer (y , * layer_args , ** valid_kwargs )
673698 y = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
674699
675- _ , new_eng_mutables = nnx .split (layer , nnx .Param , ...)
700+ nnx .pop (layer , nnx .Intermediate )
701+ _ , _ , new_eng_mutables = nnx .split (layer , nnx .Param , ...)
676702 new_eng_mutables = jax .tree .map (lambda x : jnp .expand_dims (x , axis = 0 ), new_eng_mutables )
677703 updated_mutables_chunks .append (new_eng_mutables )
678704 current_idx += 1
@@ -698,10 +724,12 @@ def layer_fn(carry, scanned_vars):
698724 l = nnx .merge (graphdef , curr_p , curr_m )
699725 l_out = l (carry , * layer_args , ** valid_kwargs )
700726 n_carry = l_out [0 ] if isinstance (l_out , tuple ) else l_out
701- _ , n_mut = nnx .split (l , nnx .Param , ...)
727+ nnx .pop (l , nnx .Intermediate )
728+ _ , _ , n_mut = nnx .split (l , nnx .Param , ...)
702729 return n_carry , n_mut
703730
704- layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
731+ if not self ._uses_linen_fp8_ops ():
732+ layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
705733 y , new_chunk_mutables = jax .lax .scan (layer_fn , y , (chunk_params , chunk_mutables ))
706734 updated_mutables_chunks .append (new_chunk_mutables )
707735 current_idx = next_boundary
@@ -742,7 +770,11 @@ def pure_layer_fn(graphdef, state_in, y_in, kv_in, dynamic_kwargs):
742770 out_y , out_kv = merged_layer (y_in , * layer_args , kv_cache = kv_in , ** dynamic_kwargs )
743771 return out_y , out_kv , nnx .state (merged_layer )
744772
745- checkpointed_fn = jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
773+ checkpointed_fn = (
774+ pure_layer_fn
775+ if self ._uses_linen_fp8_ops ()
776+ else jax .checkpoint (pure_layer_fn , policy = policy , prevent_cse = prevent_cse )
777+ )
746778
747779 for lyr in range (num_layers ):
748780 attr_name = f"{ base_name } _{ lyr } "
@@ -921,6 +953,10 @@ def get_remat_policy(self):
921953 assert cfg .remat_policy == "full" , "Remat policy needs to be on list of remat policies"
922954 return policy
923955
956+ def _uses_linen_fp8_ops (self ) -> bool :
957+ """Returns True if the quantization mode uses Linen fp8 ops incompatible with jax.checkpoint."""
958+ return self .config .quantization in ("fp8_gpu" , "fp8_nanoo" )
959+
924960 def get_norm_layer (self , num_features : int , rngs : nnx .Rngs ):
925961 """Helper to retrieve the correct normalization layer class based on config, partially applied with common arguments."""
926962 if self .config .decoder_block in (
@@ -1072,10 +1108,18 @@ def __call__(
10721108 audio_embeddings : None | jnp .ndarray = None ,
10731109 audio_masks : None | jnp .ndarray = None ,
10741110 deepstack_visual_embeds : None | list [jnp .ndarray ] = None ,
1111+ multimodal_input = None ,
10751112 ):
10761113 cfg = self .config
10771114 assert decoder_input_tokens .ndim == 2 # [batch, len]
10781115
1116+ if multimodal_input is not None :
1117+ image_embeddings = multimodal_input .image_embeddings
1118+ bidirectional_mask = multimodal_input .bidirectional_mask
1119+ image_masks = multimodal_input .image_masks
1120+ audio_embeddings = multimodal_input .audio_embeddings
1121+ audio_masks = multimodal_input .audio_masks
1122+
10791123 # [batch, length] -> [batch, length, emb_dim]
10801124 y = self ._apply_embedding (
10811125 shared_embedding ,
@@ -1223,7 +1267,6 @@ def __call__(
12231267 decoder_input_tokens = decoder_input_tokens ,
12241268 )
12251269
1226-
12271270 else :
12281271 # Non-Pipeline Run
12291272 if cfg .scan_layers :
0 commit comments