@@ -792,7 +792,11 @@ def __call__(
792792 decoder_positions ,
793793 deterministic ,
794794 model_mode ,
795+ previous_chunk ,
796+ page_state ,
797+ slot ,
795798 )
799+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args )
796800 if cfg .using_pipeline_parallelism :
797801 if cfg .pipeline_fsdp_ag_once :
798802 logical_partition_spec = self .pipeline_module .get_weight_sharding (
@@ -847,26 +851,12 @@ def __call__(
847851 else :
848852 if cfg .scan_layers :
849853 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
850- assert len (RemattedBlockLayers ) == 2 , "Scanned layers must have a length of 2 using deepseek."
851- layer_call_kwargs = {
852- "page_state" : page_state ,
853- "previous_chunk" : previous_chunk ,
854- "slot" : slot ,
855- }
856854 dense_layer = RemattedBlockLayers [0 ]
857855 moe_layer = RemattedBlockLayers [1 ]
858856 if cfg .engram_layers :
859- original_dense_call = dense_layer .__call__
860- original_moe_call = moe_layer .__call__
861- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
862- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
863-
864857 common_kwargs = {
865858 "dense_layer" : dense_layer ,
866859 "moe_layer" : moe_layer ,
867- "original_dense_call" : original_dense_call ,
868- "original_moe_call" : original_moe_call ,
869- "layer_call_kwargs" : layer_call_kwargs ,
870860 "decoder_segment_ids" : decoder_segment_ids ,
871861 "decoder_positions" : decoder_positions ,
872862 "deterministic" : deterministic ,
@@ -895,7 +885,6 @@ def __call__(
895885 ** common_kwargs ,
896886 )
897887 else :
898- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
899888 y , _ = self .scan_decoder_layers (
900889 cfg ,
901890 dense_layer ,
@@ -905,7 +894,6 @@ def __call__(
905894 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
906895 model_mode = model_mode ,
907896 )(y , * broadcast_args )
908- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
909897 num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
910898
911899 # If batch-split schedule is used and initialization is complete,
@@ -954,16 +942,38 @@ def __call__(
954942 "nope_layer_interval" : self .config .nope_layer_interval ,
955943 "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
956944 }
957- y , _ = self .scan_decoder_layers (
945+
946+ # Update broadcast_args and in_axes_tuple for vLLM RPA
947+ current_broadcast_args = list (broadcast_args )
948+ current_in_axes_tuple = list (in_axes_tuple )
949+
950+ if kv_caches is not None :
951+ # Stack kv_caches for scan: [num_layers, ...]
952+ stacked_kv_cache = jnp .stack (kv_caches , axis = 0 )
953+ current_broadcast_args .append (stacked_kv_cache )
954+ current_in_axes_tuple .append (0 ) # Scan over the layer dimension
955+ else :
956+ current_broadcast_args .append (None )
957+ current_in_axes_tuple .append (nn .broadcast )
958+
959+ current_broadcast_args .append (attention_metadata )
960+ current_in_axes_tuple .append (nn .broadcast )
961+
962+ y , returned_kv_cache = self .scan_decoder_layers (
958963 cfg ,
959964 RemattedBlockLayer ,
960965 scan_length ,
961966 "layers" ,
962967 mesh ,
963- in_axes_tuple = ( nn . broadcast ,) * len ( broadcast_args ),
968+ in_axes_tuple = tuple ( current_in_axes_tuple ),
964969 model_mode = model_mode ,
965970 ** layer_kwargs ,
966- )(y , * broadcast_args )
971+ )(y , * current_broadcast_args )
972+
973+ if kv_caches is not None and returned_kv_cache is not None :
974+ # Update the list of KV caches from the scanned results
975+ for i , cache in enumerate (returned_kv_cache ):
976+ kv_caches [i ] = cache
967977 else :
968978 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
969979 assert len (RemattedBlockLayers ) == 2 , "Unscanned layers must have a length of 2 using deepseek."
@@ -1173,10 +1183,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11731183 """Applies a single, unscanned Engram layer."""
11741184 layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
11751185 layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1176- original_call = kwargs ["original_dense_call" ] if layer_type == "dense" else kwargs ["original_moe_call" ]
1177- layer_call_kwargs = kwargs ["layer_call_kwargs" ]
1186+ broadcast_args = kwargs ["broadcast_args" ]
11781187
1179- layer .__call__ = original_call
11801188 y , _ = layer (
11811189 config = self .config ,
11821190 mesh = self .mesh ,
@@ -1186,14 +1194,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11861194 layer_idx = current_idx ,
11871195 )(
11881196 y ,
1189- kwargs ["decoder_segment_ids" ],
1190- kwargs ["decoder_positions" ],
1191- kwargs ["deterministic" ],
1192- kwargs ["model_mode" ],
1197+ * broadcast_args ,
11931198 decoder_input_tokens = kwargs ["decoder_input_tokens" ],
1194- ** layer_call_kwargs ,
11951199 )
1196- layer .__call__ = functools .partial (original_call , ** layer_call_kwargs )
11971200 return y
11981201
11991202 def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_type , ** kwargs ):
0 commit comments