@@ -792,7 +792,13 @@ def __call__(
792792 decoder_positions ,
793793 deterministic ,
794794 model_mode ,
795+ previous_chunk ,
796+ page_state ,
797+ slot ,
795798 )
799+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args )
800+ # Pipeline module only accepts (segment_ids, positions, deterministic, model_mode)
801+ pipeline_broadcast_args = broadcast_args [:4 ]
796802 if cfg .using_pipeline_parallelism :
797803 if cfg .pipeline_fsdp_ag_once :
798804 logical_partition_spec = self .pipeline_module .get_weight_sharding (
@@ -828,9 +834,9 @@ def __call__(
828834 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
829835 model_mode = model_mode ,
830836 )(y , * broadcast_args )
831- y = self .pipeline_module (y , * broadcast_args , logical_partition_spec = logical_partition_spec )
837+ y = self .pipeline_module (y , * pipeline_broadcast_args , logical_partition_spec = logical_partition_spec )
832838 else : # Not DeepSeek
833- y = self .pipeline_module (y , * broadcast_args , logical_partition_spec = logical_partition_spec )
839+ y = self .pipeline_module (y , * pipeline_broadcast_args , logical_partition_spec = logical_partition_spec )
834840 remaining_layers = self .config .num_decoder_layers - self .config .pipeline_parallel_layers
835841 if remaining_layers > 0 :
836842 logical_axis_rules_pp_as_dp = sharding .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
@@ -847,26 +853,12 @@ def __call__(
847853 else :
848854 if cfg .scan_layers :
849855 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
850- assert len (RemattedBlockLayers ) == 2 , "Scanned layers must have a length of 2 using deepseek."
851- layer_call_kwargs = {
852- "page_state" : page_state ,
853- "previous_chunk" : previous_chunk ,
854- "slot" : slot ,
855- }
856856 dense_layer = RemattedBlockLayers [0 ]
857857 moe_layer = RemattedBlockLayers [1 ]
858858 if cfg .engram_layers :
859- original_dense_call = dense_layer .__call__
860- original_moe_call = moe_layer .__call__
861- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
862- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
863-
864859 common_kwargs = {
865860 "dense_layer" : dense_layer ,
866861 "moe_layer" : moe_layer ,
867- "original_dense_call" : original_dense_call ,
868- "original_moe_call" : original_moe_call ,
869- "layer_call_kwargs" : layer_call_kwargs ,
870862 "decoder_segment_ids" : decoder_segment_ids ,
871863 "decoder_positions" : decoder_positions ,
872864 "deterministic" : deterministic ,
@@ -895,7 +887,6 @@ def __call__(
895887 ** common_kwargs ,
896888 )
897889 else :
898- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
899890 y , _ = self .scan_decoder_layers (
900891 cfg ,
901892 dense_layer ,
@@ -905,7 +896,6 @@ def __call__(
905896 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
906897 model_mode = model_mode ,
907898 )(y , * broadcast_args )
908- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
909899 num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
910900
911901 # If batch-split schedule is used and initialization is complete,
@@ -954,16 +944,38 @@ def __call__(
954944 "nope_layer_interval" : self .config .nope_layer_interval ,
955945 "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
956946 }
957- y , _ = self .scan_decoder_layers (
947+
948+ # Update broadcast_args and in_axes_tuple for vLLM RPA
949+ current_broadcast_args = list (broadcast_args )
950+ current_in_axes_tuple = list (in_axes_tuple )
951+
952+ if kv_caches is not None :
953+ # Stack kv_caches for scan: [num_layers, ...]
954+ stacked_kv_cache = jnp .stack (kv_caches , axis = 0 )
955+ current_broadcast_args .append (stacked_kv_cache )
956+ current_in_axes_tuple .append (0 ) # Scan over the layer dimension
957+ else :
958+ current_broadcast_args .append (None )
959+ current_in_axes_tuple .append (nn .broadcast )
960+
961+ current_broadcast_args .append (attention_metadata )
962+ current_in_axes_tuple .append (nn .broadcast )
963+
964+ y , returned_kv_cache = self .scan_decoder_layers (
958965 cfg ,
959966 RemattedBlockLayer ,
960967 scan_length ,
961968 "layers" ,
962969 mesh ,
963- in_axes_tuple = ( nn . broadcast ,) * len ( broadcast_args ),
970+ in_axes_tuple = tuple ( current_in_axes_tuple ),
964971 model_mode = model_mode ,
965972 ** layer_kwargs ,
966- )(y , * broadcast_args )
973+ )(y , * current_broadcast_args )
974+
975+ if kv_caches is not None and returned_kv_cache is not None :
976+ # Update the list of KV caches from the scanned results
977+ for i , cache in enumerate (returned_kv_cache ):
978+ kv_caches [i ] = cache
967979 else :
968980 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
969981 assert len (RemattedBlockLayers ) == 2 , "Unscanned layers must have a length of 2 using deepseek."
@@ -1173,10 +1185,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11731185 """Applies a single, unscanned Engram layer."""
11741186 layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
11751187 layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1176- original_call = kwargs ["original_dense_call" ] if layer_type == "dense" else kwargs ["original_moe_call" ]
1177- layer_call_kwargs = kwargs ["layer_call_kwargs" ]
1188+ broadcast_args = kwargs ["broadcast_args" ]
11781189
1179- layer .__call__ = original_call
11801190 y , _ = layer (
11811191 config = self .config ,
11821192 mesh = self .mesh ,
@@ -1186,14 +1196,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
11861196 layer_idx = current_idx ,
11871197 )(
11881198 y ,
1189- kwargs ["decoder_segment_ids" ],
1190- kwargs ["decoder_positions" ],
1191- kwargs ["deterministic" ],
1192- kwargs ["model_mode" ],
1199+ * broadcast_args ,
11931200 decoder_input_tokens = kwargs ["decoder_input_tokens" ],
1194- ** layer_call_kwargs ,
11951201 )
1196- layer .__call__ = functools .partial (original_call , ** layer_call_kwargs )
11971202 return y
11981203
11991204 def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_type , ** kwargs ):
0 commit comments