Skip to content

Commit c84d94c

Browse files
committed
fix(vllm): Fix PagedAttention return signature unpacking and unrolled scan loop
1 parent: ba92ee5 · commit: c84d94c

1 file changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxtext/layers/attentions.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -981,7 +981,7 @@ def forward_serve_vllm(
       value: Array,
       rpa_kv_cache: list[Array] | None = None,
       rpa_metadata: dict[str, Any] | None = None,
-  ) -> tuple[list[Array], Array]:
+  ) -> tuple[Array, list[Array]]:
     """Forward function for vLLM serving with RPA attention."""
     try:
       # pylint: disable=import-outside-toplevel
@@ -998,7 +998,7 @@ def forward_serve_vllm(

     if rpa_kv_cache is None or rpa_metadata is None:
       # Return dummy values for dry runs (e.g. during model initialization or JIT tracing)
-      return [], query
+      return query, []

     if self.config.sliding_window_size > 0:
       attention_chunk_size = self.config.sliding_window_size
@@ -1027,7 +1027,7 @@ def forward_serve_vllm(
           k_scale,
           v_scale,
       )
-      return kv_cache, output
+      return output, kv_cache

   def __call__(
       self,
@@ -1170,7 +1170,7 @@ def __call__(

     elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN:
       batch, seq_len, num_heads, head_dim = query.shape
-      updated_kv, attn_out = self.forward_serve_vllm(
+      attn_out, updated_kv = self.forward_serve_vllm(
           query, key, value, rpa_kv_cache=kv_cache, rpa_metadata=attention_metadata
       )
       out = attn_out.reshape(batch, seq_len, num_heads, head_dim)

0 commit comments

Comments (0)