@@ -795,7 +795,13 @@ def __call__(
795795 decoder_positions ,
796796 deterministic ,
797797 model_mode ,
798+ previous_chunk ,
799+ page_state ,
800+ slot ,
798801 )
802+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args )
803+ # Pipeline module only accepts (segment_ids, positions, deterministic, model_mode)
804+ pipeline_broadcast_args = broadcast_args [:4 ]
799805 if cfg .using_pipeline_parallelism :
800806 logical_partition_spec = (
801807 self .pipeline_module .get_weight_sharding (y , decoder_segment_ids , decoder_positions , deterministic , model_mode )
@@ -830,9 +836,9 @@ def __call__(
830836 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
831837 model_mode = model_mode ,
832838 )(y , * broadcast_args )
833- y = self .pipeline_module (y , * broadcast_args , logical_partition_spec = logical_partition_spec )
839+ y = self .pipeline_module (y , * pipeline_broadcast_args , logical_partition_spec = logical_partition_spec )
834840 else : # Not DeepSeek
835- y = self .pipeline_module (y , * broadcast_args , logical_partition_spec = logical_partition_spec )
841+ y = self .pipeline_module (y , * pipeline_broadcast_args , logical_partition_spec = logical_partition_spec )
836842 remaining_layers = self .config .num_decoder_layers - self .config .pipeline_parallel_layers
837843 if remaining_layers > 0 :
838844 logical_axis_rules_pp_as_dp = sharding .logical_axis_rules_pp_act_as_dp (self .config .logical_axis_rules )
@@ -849,26 +855,12 @@ def __call__(
849855 else :
850856 if cfg .scan_layers :
851857 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
852- assert len (RemattedBlockLayers ) == 2 , "Scanned layers must have a length of 2 using deepseek."
853- layer_call_kwargs = {
854- "page_state" : page_state ,
855- "previous_chunk" : previous_chunk ,
856- "slot" : slot ,
857- }
858858 dense_layer = RemattedBlockLayers [0 ]
859859 moe_layer = RemattedBlockLayers [1 ]
860860 if cfg .engram_layers :
861- original_dense_call = dense_layer .__call__
862- original_moe_call = moe_layer .__call__
863- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
864- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
865-
866861 common_kwargs = {
867862 "dense_layer" : dense_layer ,
868863 "moe_layer" : moe_layer ,
869- "original_dense_call" : original_dense_call ,
870- "original_moe_call" : original_moe_call ,
871- "layer_call_kwargs" : layer_call_kwargs ,
872864 "decoder_segment_ids" : decoder_segment_ids ,
873865 "decoder_positions" : decoder_positions ,
874866 "deterministic" : deterministic ,
@@ -897,7 +889,6 @@ def __call__(
897889 ** common_kwargs ,
898890 )
899891 else :
900- dense_layer .__call__ = functools .partial (dense_layer .__call__ , ** layer_call_kwargs )
901892 y , _ = self .scan_decoder_layers (
902893 cfg ,
903894 dense_layer ,
@@ -907,7 +898,6 @@ def __call__(
907898 in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
908899 model_mode = model_mode ,
909900 )(y , * broadcast_args )
910- moe_layer .__call__ = functools .partial (moe_layer .__call__ , ** layer_call_kwargs )
911901 num_moe_layers = cfg .num_decoder_layers - cfg .first_num_dense_layers
912902
913903 # If batch-split schedule is used and initialization is complete,
@@ -981,16 +971,38 @@ def __call__(
981971 "nope_layer_interval" : self .config .nope_layer_interval ,
982972 "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
983973 }
984- y , _ = self .scan_decoder_layers (
974+
975+ # Build extended copies of broadcast_args and in_axes_tuple that also carry the KV cache and attention metadata (vLLM RPA)
976+ current_broadcast_args = list (broadcast_args )
977+ current_in_axes_tuple = list (in_axes_tuple )
978+
979+ if kv_caches is not None :
980+ # Stack kv_caches for scan: [num_layers, ...]
981+ stacked_kv_cache = jnp .stack (kv_caches , axis = 0 )
982+ current_broadcast_args .append (stacked_kv_cache )
983+ current_in_axes_tuple .append (0 ) # Scan over the layer dimension
984+ else :
985+ current_broadcast_args .append (None )
986+ current_in_axes_tuple .append (nn .broadcast )
987+
988+ current_broadcast_args .append (attention_metadata )
989+ current_in_axes_tuple .append (nn .broadcast )
990+
991+ y , returned_kv_cache = self .scan_decoder_layers (
985992 cfg ,
986993 RemattedBlockLayer ,
987994 scan_length ,
988995 "layers" ,
989996 mesh ,
990- in_axes_tuple = ( nn . broadcast ,) * len ( broadcast_args ),
997+ in_axes_tuple = tuple ( current_in_axes_tuple ),
991998 model_mode = model_mode ,
992999 ** layer_kwargs ,
993- )(y , * broadcast_args )
1000+ )(y , * current_broadcast_args )
1001+
1002+ if kv_caches is not None and returned_kv_cache is not None :
1003+ # Update the list of KV caches from the scanned results
1004+ for i , cache in enumerate (returned_kv_cache ):
1005+ kv_caches [i ] = cache
9941006 else :
9951007 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
9961008 assert len (RemattedBlockLayers ) == 2 , "Unscanned layers must have a length of 2 using deepseek."
@@ -1295,10 +1307,8 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
12951307 """Applies a single, unscanned Engram layer."""
12961308 layer = kwargs ["dense_layer" ] if layer_type == "dense" else kwargs ["moe_layer" ]
12971309 layer_prefix = "dense_layers" if layer_type == "dense" else "moe_layers"
1298- original_call = kwargs ["original_dense_call" ] if layer_type == "dense" else kwargs ["original_moe_call" ]
1299- layer_call_kwargs = kwargs ["layer_call_kwargs" ]
1310+ broadcast_args = kwargs ["broadcast_args" ]
13001311
1301- layer .__call__ = original_call
13021312 y , _ = layer (
13031313 config = self .config ,
13041314 mesh = self .mesh ,
@@ -1308,14 +1318,9 @@ def _apply_single_engram_layer(self, y, current_idx, layer_type, **kwargs):
13081318 layer_idx = current_idx ,
13091319 )(
13101320 y ,
1311- kwargs ["decoder_segment_ids" ],
1312- kwargs ["decoder_positions" ],
1313- kwargs ["deterministic" ],
1314- kwargs ["model_mode" ],
1321+ * broadcast_args ,
13151322 decoder_input_tokens = kwargs ["decoder_input_tokens" ],
1316- ** layer_call_kwargs ,
13171323 )
1318- layer .__call__ = functools .partial (original_call , ** layer_call_kwargs )
13191324 return y
13201325
13211326 def _apply_scanned_chunk (self , y , current_idx , next_boundary , layer_type , ** kwargs ):
0 commit comments