@@ -428,8 +428,23 @@ def pure_layer_fn(state_in, y_in):
428428
429429 return out
430430
431- def _apply_layers_sequentially (self , layers , x_in , * args , length : int , ** kwargs ):
432- """Runs the layer stack using nnx.scan."""
431+ def _apply_layers_sequentially (self , layers , x_in , * args , length : int , kv_caches_stacked = None , ** kwargs ):
432+ """Runs the layer stack with jax.lax.scan, or with a statically unrolled loop when KV caches are supplied.
433+
434+ Args:
435+ layers: The stacked NNX module whose params are scanned over.
436+ x_in: The carry (hidden state) fed into the first layer.
437+ *args: Positional args broadcast to every layer call.
438+ length: Number of scan iterations (= number of layers).
439+ kv_caches_stacked: Optional list of per-layer KV caches, indexed by layer (despite the
440+ name it is not stacked). When provided, the layers are unrolled in a Python loop, entry i
441+ is passed as `kv_cache=` to layer i, and entry i is overwritten in place with that layer's output cache.
442+ **kwargs: Keyword args forwarded to the layer (filtered by the layer signature).
443+
444+ Returns:
445+ (final_carry, updated_layers) when kv_caches_stacked is None.
446+ (final_carry, layers, None) otherwise; `layers` is returned unchanged and the caches are updated in place in kv_caches_stacked.
447+ """
433448 policy = self .get_remat_policy ()
434449 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (self .config )
435450 graphdef , params , state = nnx .split (
@@ -450,35 +465,83 @@ def _apply_layers_sequentially(self, layers, x_in, *args, length: int, **kwargs)
450465 # Filter kwargs to only include keys that exist in the layer's signature
451466 valid_kwargs = {k : v for k , v in kwargs .items () if k in sig .parameters or "kwargs" in sig .parameters }
452467
468+ use_kv = kv_caches_stacked is not None
469+
453470 def layer_fn (carry , scanned_vars ):
454471 # Unpack the sliced variables for THIS layer
455- current_params , current_state = scanned_vars
472+ if use_kv :
473+ current_params , current_state , kv_cache_layer = scanned_vars
474+ else :
475+ current_params , current_state = scanned_vars
476+ kv_cache_layer = None
456477
457478 if self .config .parameter_memory_host_offload :
458479 current_params = jax .tree .map (lambda x : jax .device_put (x , max_utils .device_space ()), current_params )
459480
460481 # Merge using the SLICED state
461482 layer = nnx .merge (graphdef , current_params , current_state )
462483
463- # Run the layer (Filter kwargs if using the solution from previous turn)
464- layer_out = layer (carry , * args , ** valid_kwargs )
484+ # Build call kwargs, injecting per-layer kv_cache when available
485+ call_kwargs = dict (valid_kwargs )
486+ if kv_cache_layer is not None :
487+ call_kwargs ["kv_cache" ] = kv_cache_layer
465488
466- new_carry = layer_out [0 ] if isinstance (layer_out , tuple ) else layer_out
489+ layer_out = layer (carry , * args , ** call_kwargs )
490+
491+ if isinstance (layer_out , tuple ):
492+ new_carry = layer_out [0 ]
493+ updated_kv = layer_out [1 ] if len (layer_out ) > 1 else None
494+ else :
495+ new_carry = layer_out
496+ updated_kv = None
467497
468498 # Extract the updated state to return it
469- # _, new_current_state = nnx.split(layer, nnx.Param, ...)
470499 new_current_state = nnx .state (layer )
500+
501+ if use_kv :
502+ return new_carry , (new_current_state , updated_kv )
471503 return new_carry , new_current_state
472504
473505 layer_fn = jax .checkpoint (layer_fn , policy = policy , prevent_cse = prevent_cse )
474506
475- final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
507+ if use_kv :
508+ # If kv_caches is provided (e.g., from vLLM), we CANNOT use jax.lax.scan
509+ # because scanning requires stacking the kv_caches list, which creates a copy
510+ # and breaks the in-place memory updates required by vLLM's PagedAttention.
511+ # Therefore, we must unroll the loop statically when kv_caches is provided.
512+
513+ # Despite its name, kv_caches_stacked is the caller's per-layer list of KV caches, not a stacked array.
514+ kv_caches_list = kv_caches_stacked
515+
516+ current_carry = x_in
517+
518+ for i in range (length ):
519+ # Statically slice the parameters and state for this layer
520+ current_params = jax .tree .map (lambda x : x [i ], params )
521+ current_state = jax .tree .map (lambda x : x [i ], state )
522+
523+ # Call the layer
524+ current_carry , (new_current_state , updated_kv ) = layer_fn (
525+ current_carry , (current_params , current_state , kv_caches_list [i ])
526+ )
527+
528+ # Update the list in-place (mutates the list passed by reference)
529+ kv_caches_list [i ] = updated_kv
530+
531+ # We don't need to rebuild scanned_state or return it because during
532+ # inference with vLLM, parameters do not change and we don't need intermediates.
533+ return current_carry , layers , None
534+ else :
535+ final_carry , scanned_state = jax .lax .scan (layer_fn , x_in , (params , state ))
536+ returned_kv_stacked = None
476537
477538 if scan_axis != 0 :
478539 scanned_params , scanned_other = scanned_state .split (nnx .Param , ...)
479540 scanned_params = jax .tree .map (lambda x : jnp .moveaxis (x , 0 , scan_axis ), scanned_params )
480541 scanned_state = nnx .State .merge (scanned_params , scanned_other )
481542
543+ if use_kv :
544+ return final_carry , nnx .merge (graphdef , scanned_state ), returned_kv_stacked
482545 return final_carry , nnx .merge (graphdef , scanned_state )
483546
484547 def get_decoder_layers (self ):
@@ -1001,7 +1064,19 @@ def __call__(
10011064 )
10021065 else :
10031066 scan_length = int (cfg .num_decoder_layers / cfg .inhomogeneous_layer_cycle_interval )
1004- y , self .layers = self ._apply_layers_sequentially (self .layers , y , * layer_args , length = scan_length , ** layer_kwargs )
1067+ if kv_caches is not None :
1068+ # Pass the kv_caches list directly to avoid copying in jnp.stack,
1069+ # which breaks vLLM PagedAttention in-place memory updates.
1070+ # The _apply_layers_sequentially function will handle it by statically unrolling.
1071+ y , self .layers , returned_kv = self ._apply_layers_sequentially (
1072+ self .layers , y , * layer_args , length = scan_length ,
1073+ kv_caches_stacked = kv_caches , ** layer_kwargs
1074+ )
1075+ # kv_caches list is updated in-place inside _apply_layers_sequentially
1076+ else :
1077+ y , self .layers = self ._apply_layers_sequentially (
1078+ self .layers , y , * layer_args , length = scan_length , ** layer_kwargs
1079+ )
10051080 else :
10061081 prevent_cse = maxtext_utils .should_prevent_cse_in_remat (cfg )
10071082
0 commit comments