Skip to content

Commit a7bc451

Browse files
authored
fix: GLM-5.1 on ROCm (#1637)
Signed-off-by: AlpinDale <alpindale@gmail.com>
1 parent b4835a5 commit a7bc451

3 files changed

Lines changed: 64 additions & 24 deletions

File tree

aphrodite/v1/attention/backends/mla/rocm_aiter_mla.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33

44
from dataclasses import dataclass
5-
from typing import ClassVar
5+
from typing import ClassVar, Final
66

77
import torch
88

@@ -358,6 +358,51 @@ def _expand_page_indices_kernel(
358358
)
359359

360360

361+
class AiterMLAHelper:
362+
"""
363+
AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and
364+
16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.
365+
"""
366+
367+
_AITER_MIN_MLA_HEADS: Final = 16
368+
369+
@staticmethod
370+
def check_num_heads_validity(num_heads: int):
371+
assert AiterMLAHelper.is_valid_num_heads(num_heads), (
372+
f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
373+
f"but provided {num_heads} number of heads.\n"
374+
f"Try adjusting tensor_parallel_size value."
375+
)
376+
377+
@staticmethod
378+
def is_valid_num_heads(num_heads: int) -> bool:
379+
return (
380+
num_heads % AiterMLAHelper._AITER_MIN_MLA_HEADS == 0
381+
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
382+
else AiterMLAHelper._AITER_MIN_MLA_HEADS % num_heads == 0
383+
)
384+
385+
@staticmethod
386+
def get_actual_mla_num_heads(num_heads: int) -> int:
387+
return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)
388+
389+
@staticmethod
390+
def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
391+
return (
392+
q
393+
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
394+
else q.repeat_interleave(AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, dim=1)
395+
)
396+
397+
@staticmethod
398+
def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
399+
return (
400+
o
401+
if num_heads >= AiterMLAHelper._AITER_MIN_MLA_HEADS
402+
else o[:, :: AiterMLAHelper._AITER_MIN_MLA_HEADS // num_heads, :]
403+
)
404+
405+
361406
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
362407
def __init__(
363408
self,
@@ -387,15 +432,8 @@ def __init__(
387432
kv_sharing_target_layer_name,
388433
**mla_args,
389434
)
390-
_valid_heads = num_heads in (4, 8) or (num_heads % 16 == 0 and 16 <= num_heads <= 128)
391-
assert _valid_heads, (
392-
f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
393-
f"in [16, 128].\n"
394-
f"Provided {num_heads} number of heads.\n"
395-
"Try adjusting tensor_parallel_size value."
396-
)
397-
self._needs_head_repeat = num_heads < 16
398-
self._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
435+
AiterMLAHelper.check_num_heads_validity(num_heads)
436+
399437
unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap]
400438
if any(unsupported_features):
401439
raise NotImplementedError(
@@ -435,15 +473,12 @@ def forward_mqa(
435473
assert isinstance(q, torch.Tensor)
436474
B = q.shape[0]
437475

438-
if self._needs_head_repeat:
439-
q = q.repeat_interleave(self._head_repeat_factor, dim=1)
440-
kernel_num_heads = 16
441-
else:
442-
kernel_num_heads = self.num_heads
476+
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
477+
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
443478

444479
o = torch.empty(
445480
B,
446-
kernel_num_heads,
481+
mla_num_heads,
447482
self.kv_lora_rank,
448483
dtype=attn_metadata.decode.attn_out_dtype,
449484
device=q.device,
@@ -470,7 +505,7 @@ def forward_mqa(
470505
)
471506

472507
rocm_aiter_ops.mla_decode_fwd(
473-
q,
508+
mla_padded_q,
474509
kv_buffer,
475510
o,
476511
self.scale,
@@ -482,7 +517,4 @@ def forward_mqa(
482517
**mla_kwargs,
483518
)
484519

485-
if self._needs_head_repeat:
486-
o = o[:, :: self._head_repeat_factor, :]
487-
488-
return o, None
520+
return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, o), None

aphrodite/v1/attention/backends/mla/rocm_aiter_mla_sparse.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from aphrodite.v1.attention.backends.mla.flashmla_sparse import (
2929
triton_convert_req_index_to_global_index,
3030
)
31+
from aphrodite.v1.attention.backends.mla.rocm_aiter_mla import AiterMLAHelper
3132
from aphrodite.v1.kv_cache_interface import AttentionSpec
3233

3334
if TYPE_CHECKING:
@@ -277,6 +278,8 @@ def __init__(
277278
indexer: "Indexer | None" = None,
278279
**mla_args,
279280
) -> None:
281+
AiterMLAHelper.check_num_heads_validity(num_heads)
282+
280283
self.num_heads = num_heads
281284
self.head_size = head_size
282285
self.scale = float(scale)
@@ -295,8 +298,9 @@ def _forward_bf16_kv(
295298
attn_metadata: ROCMAiterMLASparseMetadata,
296299
) -> torch.Tensor:
297300
num_tokens = q.shape[0]
301+
mla_num_heads = AiterMLAHelper.get_actual_mla_num_heads(self.num_heads)
298302
output = torch.empty(
299-
[num_tokens, self.num_heads, self.kv_lora_rank],
303+
[num_tokens, mla_num_heads, self.kv_lora_rank],
300304
dtype=q.dtype,
301305
device=q.device,
302306
)
@@ -322,7 +326,7 @@ def _forward_bf16_kv(
322326
attn_metadata.paged_kv_last_page_len,
323327
)
324328

325-
return output[:, : self.num_heads, :]
329+
return AiterMLAHelper.get_mla_unpadded_o(self.num_heads, output)
326330

327331
def forward_mqa(
328332
self,
@@ -352,6 +356,7 @@ def forward_mqa(
352356
NUM_TOPK_TOKENS=attn_metadata.topk_tokens,
353357
)
354358

355-
attn_out = self._forward_bf16_kv(q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata)
359+
mla_padded_q = AiterMLAHelper.get_mla_padded_q(self.num_heads, q)
360+
attn_out = self._forward_bf16_kv(mla_padded_q, kv_c_and_k_pe_cache, topk_indices_global, attn_metadata)
356361

357362
return attn_out, None

aphrodite/v1/attention/ops/rocm_aiter_mla_sparse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,8 @@ def rocm_fp8_paged_mqa_logits(
321321
device="cuda",
322322
dtype=torch.float32,
323323
)
324+
# TODO: 1. Replace _stage1 and out_qk.sum with another fused variant;
325+
# 2. Remove ChunkQ when AITER PR #2891 merged
324326
deepgemm_fp8_paged_mqa_logits_stage1(
325327
q_fp8,
326328
kv_cache_fp8,
@@ -329,6 +331,7 @@ def rocm_fp8_paged_mqa_logits(
329331
context_lens,
330332
block_tables,
331333
max_model_len,
334+
ChunkQ=heads,
332335
)
333336
return out_qk.sum(dim=0)
334337
else:

0 commit comments

Comments
 (0)