@@ -428,8 +428,23 @@ def pure_layer_fn(state_in, y_in):
428428
429429 return out
430430
431- def _apply_layers_sequentially (self , layers , x_in , * args , length : int , ** kwargs ):
432- """Runs the layer stack using nnx.scan."""
431+ def _apply_layers_sequentially (self , layers , x_in , * args , length : int , kv_caches_stacked = None , ** kwargs ):
432+ """Runs the layer stack with jax.lax.scan, or a statically unrolled loop when kv caches are given.
433+
434+ Args:
435+ layers: The stacked NNX module whose params are scanned over.
436+ x_in: The carry (hidden state) fed into the first layer.
437+ *args: Positional args broadcast to every layer call.
438+ length: Number of scan iterations (= number of layers).
439+ kv_caches_stacked: Optional mutable list of per-layer KV caches (despite the name, it is
440+ indexed per layer, not sliced from a stacked array). When provided, entry i is passed as
441+ `kv_cache=` to layer i, the layers run in an unrolled loop, and entry i is overwritten in place.
442+ **kwargs: Keyword args forwarded to the layer (filtered by the layer signature).
443+
444+ Returns:
445+ (final_carry, updated_layers, None) when kv_caches_stacked is None.
446+ (final_carry, layers, None) otherwise; `layers` is returned as passed in and the caches are updated in place in the list.
447+ """
433448 policy = self .get_remat_policy ()
434449 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
435450 graphdef , params , state = nnx .split (
@@ -450,36 +465,80 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
450465 # Filter kwargs to only include keys that exist in the layer's signature
451466 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
452467
468+ use_kv = kv_caches_stacked is not None
469+
453470 def layer_fn (carry , scanned_vars ):
454471 # Unpack the sliced variables for THIS layer
455- current_params , current_state = scanned_vars
472+ if use_kv :
473+ current_params , current_state , kv_cache_layer = scanned_vars
474+ else :
475+ current_params , current_state = scanned_vars
476+ kv_cache_layer = None
456477
457478 if self .config .parameter_memory_host_offload :
458479 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
459480
460481 # Merge using the SLICED state
461482 layer = nnx .merge (graphdef , current_params , current_state )
462483
463- # Run the layer (Filter kwargs if using the solution from previous turn)
464- layer_out = layer (carry , * args , ** valid_kwargs )
484+ # Build call kwargs, injecting per-layer kv_cache when available
485+ call_kwargs = dict (valid_kwargs )
486+ if kv_cache_layer is not None :
487+ call_kwargs ["kv_cache" ] = kv_cache_layer
488+
489+ layer_out = layer (carry , * args , ** call_kwargs )
465490
466- new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
491+ if isinstance (layer_out , tuple ):
492+ new_carry = layer_out [0 ]
493+ updated_kv = layer_out [1 ] if len (layer_out ) > 1 else None
494+ else :
495+ new_carry = layer_out
496+ updated_kv = None
467497
468498 # Extract the updated state to return it
469- # _, new_current_state = nnx.split(layer, nnx.Param, ...)
470499 new_current_state = nnx .state (layer )
500+
501+ if use_kv :
502+ return new_carry , (new_current_state , updated_kv )
471503 return new_carry , new_current_state
472504
473505 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
474506
475- final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
507+ if use_kv :
508+ # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
509+ # because scanning requires stacking the kv_caches list, which creates a copy
510+ # and breaks the in-place memory updates required by vLLM's PagedAttention.
511+ # Therefore, we must unroll the loop statically when kv_caches is provided.
512+
513+ # kv_caches_stacked is actually the original kv_caches list in this new flow
514+ kv_caches_list = kv_caches_stacked
515+
516+ current_carry = x_in
517+
518+ for i in range (length ):
519+ # Statically slice the parameters and state for this layer
520+ current_params = jax .tree .map (lambda x , i = i : x [i ], params )
521+ current_state = jax .tree .map (lambda x , i = i : x [i ], state )
522+
523+ # Call the layer
524+ current_carry , (_ , updated_kv ) = layer_fn (current_carry , (current_params , current_state , kv_caches_list [i ]))
525+
526+ # Update the list in-place (mutates the list passed by reference)
527+ kv_caches_list [i ] = updated_kv
528+
529+ # We don't need to rebuild scanned_state or return it because during
530+ # inference with vLLM, parameters do not change and we don't need intermediates.
531+ return current_carry , layers , None
532+ else :
533+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
534+ returned_kv_stacked = None
476535
477536 if scan_axis != 0 :
478537 scanned_params , scanned_other = scanned_state .split (nnx .Param , ...)
479538 scanned_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), scanned_params )
480539 scanned_state = nnx .State .merge (scanned_params , scanned_other )
481540
482- return final_carry , nnx .merge (graphdef , scanned_state )
541+ return final_carry , nnx .merge (graphdef , scanned_state ), returned_kv_stacked if use_kv else None
483542
484543 def get_decoder_layers (self ):
485544 """Retrieves decoder layer classes based on config using a dictionary lookup."""
@@ -859,7 +918,7 @@ def _apply_scanned_chunk(self, y, current_idx, next_boundary, layer_stack, *args
859918 chunk_stack = nnx .merge (graphdef , chunk_state )
860919
861920 # Apply sequentially
862- y , chunk_stack = self ._apply_layers_sequentially (
921+ y , chunk_stack , _ = self ._apply_layers_sequentially (
863922 chunk_stack , y , * args , length = scan_length , ** kwargs .get ("layer_kwargs" , {})
864923 )
865924
@@ -966,7 +1025,7 @@ def __call__(
9661025 ** common_kwargs ,
9671026 )
9681027 else :
969- y , self .dense_layers = self ._apply_layers_sequentially (
1028+ y , self .dense_layers , _ = self ._apply_layers_sequentially (
9701029 self .dense_layers , y , * layer_args , length = cfg .first_num_dense_layers , ** layer_kwargs
9711030 )
9721031
@@ -984,7 +1043,7 @@ def __call__(
9841043 num_layers = num_moe ,
9851044 )
9861045 else :
987- y , self .moe_layer = self ._apply_layers_sequentially (
1046+ y , self .moe_layer , _ = self ._apply_layers_sequentially (
9881047 self .moe_layer , y , * layer_args , length = num_moe , ** layer_kwargs
9891048 )
9901049 elif self .is_gemma3 :
@@ -1001,7 +1060,18 @@ def __call__(
10011060 )
10021061 else :
10031062 scan_length = int (cfg .num_decoder_layers / cfg .inhomogeneous_layer_cycle_interval )
1004- y , self .layers = self ._apply_layers_sequentially (self .layers , y , * layer_args , length = scan_length , ** layer_kwargs )
1063+ if kv_caches is not None :
1064+ # Pass the kv_caches list directly to avoid copying in jnp.stack,
1065+ # which breaks vLLM PagedAttention in-place memory updates.
1066+ # The _apply_layers_sequentially function will handle it by statically unrolling.
1067+ y , self .layers , _ = self ._apply_layers_sequentially (
1068+ self .layers , y , * layer_args , length = scan_length , kv_caches_stacked = kv_caches , ** layer_kwargs
1069+ )
1070+ # kv_caches list is updated in-place inside _apply_layers_sequentially
1071+ else :
1072+ y , self .layers , _ = self ._apply_layers_sequentially (
1073+ self .layers , y , * layer_args , length = scan_length , ** layer_kwargs
1074+ )
10051075 else :
10061076 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (cfg )
10071077
@@ -1085,7 +1155,7 @@ def _apply_gemma3_scanned_blocks(
10851155
10861156 # Apply the main scan over the full blocks
10871157 if scan_length > 0 :
1088- y , self .layers = self ._apply_layers_sequentially (self .layers , y , * layer_args , length = scan_length , ** layer_kwargs )
1158+ y , self .layers , _ = self ._apply_layers_sequentially (self .layers , y , * layer_args , length = scan_length , ** layer_kwargs )
10891159
10901160 # Apply any remaining layers that did not fit into a full scanned block
10911161 num_remaining_layers = cfg .num_decoder_layers % attention_pattern_length
0 commit comments