@@ -981,16 +981,54 @@ def __call__(
981981 "nope_layer_interval" : self .config .nope_layer_interval ,
982982 "interleave_moe_layer_step" : self .config .interleave_moe_layer_step ,
983983 }
984- y , _ = self .scan_decoder_layers (
985- cfg ,
986- RemattedBlockLayer ,
987- scan_length ,
988- "layers" ,
989- mesh ,
990- in_axes_tuple = (nn .broadcast ,) * len (broadcast_args ),
991- model_mode = model_mode ,
992- ** layer_kwargs ,
993- )(y , * broadcast_args )
984+ # Update broadcast_args and in_axes_tuple for vLLM RPA
985+ in_axes_tuple = (nn .broadcast ,) * len (broadcast_args )
986+ current_broadcast_args = list (broadcast_args )
987+ current_in_axes_tuple = list (in_axes_tuple )
988+
989+ current_broadcast_args .append (attention_metadata )
990+ current_in_axes_tuple .append (nn .broadcast )
991+
992+ if kv_caches is not None :
993+ # Stack kv_caches for scan: [num_layers, ...]
994+ stacked_kv_cache = jnp .stack (kv_caches , axis = 0 )
995+
996+ # We pass (y, stacked_kv_cache, 0) as the carry
997+ carry = (y , stacked_kv_cache , 0 )
998+
999+ # We don't pass kv_cache as a scanned argument anymore
1000+
1001+ final_carry , _ = self .scan_decoder_layers (
1002+ cfg ,
1003+ RemattedBlockLayer ,
1004+ scan_length ,
1005+ "layers" ,
1006+ mesh ,
1007+ in_axes_tuple = tuple (current_in_axes_tuple ),
1008+ model_mode = model_mode ,
1009+ ** layer_kwargs ,
1010+ )(carry , * current_broadcast_args )
1011+
1012+ y , returned_kv_cache , _ = final_carry
1013+
1014+ # Update the list of KV caches from the scanned results
1015+ for i in range (cfg .num_decoder_layers ):
1016+ kv_caches [i ] = returned_kv_cache [i ]
1017+ else :
1018+ # Fallback to old behavior if kv_caches is None (not vLLM RPA)
1019+ current_broadcast_args .append (None )
1020+ current_in_axes_tuple .append (nn .broadcast )
1021+
1022+ y , _ = self .scan_decoder_layers (
1023+ cfg ,
1024+ RemattedBlockLayer ,
1025+ scan_length ,
1026+ "layers" ,
1027+ mesh ,
1028+ in_axes_tuple = tuple (current_in_axes_tuple ),
1029+ model_mode = model_mode ,
1030+ ** layer_kwargs ,
1031+ )(y , * current_broadcast_args )
9941032 else :
9951033 if cfg .decoder_block == DecoderBlockType .DEEPSEEK :
9961034 assert len (RemattedBlockLayers ) == 2 , "Unscanned layers must have a length of 2 using deepseek."
0 commit comments