22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
44from dataclasses import dataclass
5- from typing import ClassVar
5+ from typing import ClassVar , Final
66
77import torch
88
@@ -358,6 +358,51 @@ def _expand_page_indices_kernel(
358358 )
359359
360360
class AiterMLAHelper:
    """
    Head-count helpers for the AITER MLA kernel call.

    AITER MLA implementation requires num_heads >= 16. If num_heads < 16 and
    16 % num_heads == 0, we can pad q to 16 heads; otherwise AITER has to fail.

    All helpers are stateless static methods. The padding helpers do not
    re-validate ``num_heads``; callers are expected to have passed it through
    ``check_num_heads_validity`` first.
    """

    # Smallest head count accepted by the kernel (see class docstring).
    _AITER_MIN_MLA_HEADS: Final[int] = 16

    @staticmethod
    def check_num_heads_validity(num_heads: int) -> None:
        """Raise if ``num_heads`` cannot be served, even with padding.

        Args:
            num_heads: Number of attention heads handled by this rank.

        Raises:
            AssertionError: If ``num_heads`` is neither a positive multiple
                nor a positive divisor of the minimum head count.
        """
        # Raised explicitly instead of via an ``assert`` statement so the
        # check is not stripped when Python runs with ``-O``; the exception
        # type and message are the same as an ``assert`` would produce.
        if not AiterMLAHelper.is_valid_num_heads(num_heads):
            raise AssertionError(
                f"Aiter MLA requires that num_heads be multiples or divisors of 16, "
                f"but provided {num_heads} number of heads.\n "
                f"Try adjusting tensor_parallel_size value."
            )

    @staticmethod
    def is_valid_num_heads(num_heads: int) -> bool:
        """Return whether ``num_heads`` is usable, directly or after padding.

        Valid values are positive multiples of the minimum head count (used
        as-is) and positive divisors of it (padded up to the minimum).
        Zero and negative values are always invalid.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        # Guard first: ``min_heads % 0`` would raise ZeroDivisionError and a
        # negative divisor such as -4 would otherwise pass the modulo test.
        if num_heads <= 0:
            return False
        if num_heads >= min_heads:
            return num_heads % min_heads == 0
        return min_heads % num_heads == 0

    @staticmethod
    def get_actual_mla_num_heads(num_heads: int) -> int:
        """Return the head count after padding.

        This is ``num_heads`` itself when it already meets the minimum,
        otherwise the minimum head count.
        """
        return max(num_heads, AiterMLAHelper._AITER_MIN_MLA_HEADS)

    @staticmethod
    def get_mla_padded_q(num_heads: int, q: torch.Tensor) -> torch.Tensor:
        """Pad ``q`` along dim 1 up to the minimum head count when needed.

        Args:
            num_heads: Original number of heads; must already be validated.
            q: Query tensor whose dim 1 is repeated (presumably the head
                dimension of size ``num_heads`` - confirm against caller).

        Returns:
            ``q`` itself when ``num_heads`` meets the minimum; otherwise a new
            tensor in which every dim-1 entry is repeated consecutively
            ``16 // num_heads`` times.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        if num_heads >= min_heads:
            return q
        return q.repeat_interleave(min_heads // num_heads, dim=1)

    @staticmethod
    def get_mla_unpadded_o(num_heads: int, o: torch.Tensor) -> torch.Tensor:
        """Undo the head padding on a 3-D output tensor.

        Args:
            num_heads: Original number of heads; must already be validated.
            o: 3-D output tensor whose dim 1 holds the padded heads.

        Returns:
            ``o`` itself when no padding was applied; otherwise a strided
            slice that keeps the first entry of each group of
            ``16 // num_heads`` repeated heads along dim 1.
        """
        min_heads = AiterMLAHelper._AITER_MIN_MLA_HEADS
        if num_heads >= min_heads:
            return o
        # Repeated heads are adjacent (repeat_interleave), so a stride of the
        # repeat factor picks exactly one copy per original head.
        return o[:, :: min_heads // num_heads, :]
405+
361406class AiterMLAImpl (MLACommonImpl [AiterMLAMetadata ]):
362407 def __init__ (
363408 self ,
@@ -387,15 +432,8 @@ def __init__(
387432 kv_sharing_target_layer_name ,
388433 ** mla_args ,
389434 )
390- _valid_heads = num_heads in (4 , 8 ) or (num_heads % 16 == 0 and 16 <= num_heads <= 128 )
391- assert _valid_heads , (
392- f"Aiter MLA supports num_heads of 4, 8, or multiples of 16 "
393- f"in [16, 128].\n "
394- f"Provided { num_heads } number of heads.\n "
395- "Try adjusting tensor_parallel_size value."
396- )
397- self ._needs_head_repeat = num_heads < 16
398- self ._head_repeat_factor = 16 // num_heads if num_heads < 16 else 1
435+ AiterMLAHelper .check_num_heads_validity (num_heads )
436+
399437 unsupported_features = [alibi_slopes , sliding_window , logits_soft_cap ]
400438 if any (unsupported_features ):
401439 raise NotImplementedError (
@@ -435,15 +473,12 @@ def forward_mqa(
435473 assert isinstance (q , torch .Tensor )
436474 B = q .shape [0 ]
437475
438- if self ._needs_head_repeat :
439- q = q .repeat_interleave (self ._head_repeat_factor , dim = 1 )
440- kernel_num_heads = 16
441- else :
442- kernel_num_heads = self .num_heads
476+ mla_padded_q = AiterMLAHelper .get_mla_padded_q (self .num_heads , q )
477+ mla_num_heads = AiterMLAHelper .get_actual_mla_num_heads (self .num_heads )
443478
444479 o = torch .empty (
445480 B ,
446- kernel_num_heads ,
481+ mla_num_heads ,
447482 self .kv_lora_rank ,
448483 dtype = attn_metadata .decode .attn_out_dtype ,
449484 device = q .device ,
@@ -470,7 +505,7 @@ def forward_mqa(
470505 )
471506
472507 rocm_aiter_ops .mla_decode_fwd (
473- q ,
508+ mla_padded_q ,
474509 kv_buffer ,
475510 o ,
476511 self .scale ,
@@ -482,7 +517,4 @@ def forward_mqa(
482517 ** mla_kwargs ,
483518 )
484519
485- if self ._needs_head_repeat :
486- o = o [:, :: self ._head_repeat_factor , :]
487-
488- return o , None
520+ return AiterMLAHelper .get_mla_unpadded_o (self .num_heads , o ), None
0 commit comments