@@ -981,7 +981,7 @@ def forward_serve_vllm(
981981 value : Array ,
982982 rpa_kv_cache : list [Array ] | None = None ,
983983 rpa_metadata : dict [str , Any ] | None = None ,
984- ) -> tuple [list [ Array ], Array ]:
984+ ) -> tuple [Array , list [ Array ] ]:
985985 """Forward function for vLLM serving with RPA attention."""
986986 try :
987987 # pylint: disable=import-outside-toplevel
@@ -998,7 +998,7 @@ def forward_serve_vllm(
998998
999999 if rpa_kv_cache is None or rpa_metadata is None :
10001000 # Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
1001- return [], query
1001+ return query , []
10021002
10031003 if self .config .sliding_window_size > 0 :
10041004 attention_chunk_size = self .config .sliding_window_size
@@ -1027,7 +1027,7 @@ def forward_serve_vllm(
10271027 k_scale ,
10281028 v_scale ,
10291029 )
1030- return kv_cache , output
1030+ return output , kv_cache
10311031
10321032 def __call__ (
10331033 self ,
@@ -1170,7 +1170,7 @@ def __call__(
11701170
11711171 elif self .config .attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN :
11721172 batch , seq_len , num_heads , head_dim = query .shape
1173- updated_kv , attn_out = self .forward_serve_vllm (
1173+ attn_out , updated_kv = self .forward_serve_vllm (
11741174 query , key , value , rpa_kv_cache = kv_cache , rpa_metadata = attention_metadata
11751175 )
11761176 out = attn_out .reshape (batch , seq_len , num_heads , head_dim )
0 commit comments