Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 76 additions & 90 deletions aphrodite/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,12 @@ def get_attn_backend_cls(
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
"APHRODITE_MLA_DISABLE=1 to disable MLA for this model."
)
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. Set APHRODITE_USE_V1=1 to enable them."
)

from aphrodite.attention.ops.flashmla import is_flashmla_dense_supported
from aphrodite.attention.utils.fa_utils import flash_attn_supports_mla

if use_sparse:
logger.info_once("Using Sparse MLA backend on V1 engine.", scope="global")
logger.info_once("Using Sparse MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.flashmla_sparse.FlashMLASparseBackend"

use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
Expand All @@ -281,13 +277,13 @@ def get_attn_backend_cls(
use_triton = selected_backend == _Backend.TRITON_MLA or (selected_backend is None)

if use_cutlassmla:
logger.info_once("Using Cutlass MLA backend on V1 engine.", scope="local")
logger.info_once("Using Cutlass MLA backend.", scope="local")
return "aphrodite.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
if use_flashinfermla:
from aphrodite.v1.attention.backends.utils import set_kv_cache_layout

set_kv_cache_layout("HND")
logger.info_once("Using FlashInfer MLA backend on V1 engine.", scope="global")
logger.info_once("Using FlashInfer MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
if use_flashmla:
if block_size % 64 != 0:
Expand All @@ -296,106 +292,96 @@ def get_attn_backend_cls(
block_size,
)
else:
logger.info_once("Using FlashMLA backend on V1 engine.", scope="global")
logger.info_once("Using FlashMLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.flashmla.FlashMLABackend"
if use_flashattn:
logger.info_once("Using FlashAttention MLA backend on V1 engine.", scope="global")
logger.info_once("Using FlashAttention MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
if use_triton:
logger.info_once("Using Triton MLA backend on V1 engine.", scope="global")
logger.info_once("Using Triton MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.triton_mla.TritonMLABackend"
if use_v1:
FLASHINFER_V1 = "aphrodite.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN = "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "aphrodite.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "aphrodite.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501

use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")
FLASHINFER_V1 = "aphrodite.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
TRITON_ATTN = "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "aphrodite.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "aphrodite.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "aphrodite.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501

if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.", scope="global")
if cls.has_device_capability(100):
from aphrodite.v1.attention.backends.utils import set_kv_cache_layout
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8")

set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend on V1 engine.", scope="global")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend on V1 engine.", scope="global")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend on V1 engine.", scope="global")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend on V1 engine.", scope="global")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend on V1 engine.", scope="global")
return XFORMERS_V1

from aphrodite.attention.selector import is_attn_backend_supported

# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(FLASHINFER_V1, head_size, dtype):
from aphrodite.v1.attention.backends.utils import set_kv_cache_layout
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend.")
if cls.has_device_capability(100):
from aphrodite.v1.attention.backends.utils import set_kv_cache_layout

logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs.",
scope="global",
)
set_kv_cache_layout("HND")
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend.")
return XFORMERS_V1

from aphrodite.attention.selector import is_attn_backend_supported

# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(FLASHINFER_V1, head_size, dtype):
from aphrodite.v1.attention.backends.utils import set_kv_cache_layout

return FLASHINFER_V1
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs.",
scope="global",
)
set_kv_cache_layout("HND")

if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.",
scope="global",
)
return FLASHINFER_V1

# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend on V1 engine.", scope="global")
return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention backend on V1 engine.", scope="global")
return FLASH_ATTN_V1

# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend on V1 engine.", scope="global")
return FLEX_ATTENTION_V1
if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import for V1 engine on "
"Blackwell (SM 10.0) GPUs; it is recommended to "
"install FlashInfer for better performance.",
scope="global",
)

assert not is_default_backend_supported
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend.", scope="global")
return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention backend.", scope="global")
return FLASH_ATTN_V1

use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
assert not is_default_backend_supported

logger.info_once(
"Using FlexAttention backend for %s on V1 engine.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
scope="global",
)
return FLEX_ATTENTION_V1
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype

raise RuntimeError(
"V0 attention backends have been removed. Set APHRODITE_USE_V1=1 to select a supported backend."
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
)
Comment on lines +380 to 383

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The scope="global" argument was dropped from this logger.info_once call. With scope="global" the message is emitted only once across all workers; without it the call falls back to the default scope and every worker process logs the message once, which produces duplicated log lines in a multi-worker environment. This looks like an oversight, since other logging calls in this file (for example the Triton and Flash Attention default-backend messages) still pass scope="global".

Suggested change
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
)
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
scope="global",
)

return FLEX_ATTENTION_V1

@classmethod
def get_punica_wrapper(cls) -> str:
Expand Down
6 changes: 1 addition & 5 deletions aphrodite/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,11 +453,7 @@ def use_all_gather(cls) -> bool:
"""
Whether to use allgather in LogitsProcessor to gather the logits.
"""
import aphrodite.envs as envs
from aphrodite.config import get_current_aphrodite_config

parallel_config = get_current_aphrodite_config().parallel_config
return envs.APHRODITE_USE_V1 or parallel_config.distributed_executor_backend == "external_launcher"
return True

@classmethod
def use_custom_allreduce(cls) -> bool:
Expand Down
67 changes: 26 additions & 41 deletions aphrodite/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def use_rocm_custom_paged_attention(
# disabled due to observed numerical discrepancy.
if ON_GFX9:
return (
(not envs.APHRODITE_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
(sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
Expand All @@ -160,7 +160,7 @@ def use_rocm_custom_paged_attention(
else:
return (
ON_GFX11_GFX12
and (not envs.APHRODITE_USE_V1 or sliding_window == 0 or sliding_window == (-1, -1))
and (sliding_window == 0 or sliding_window == (-1, -1))
and (qtype == torch.half or qtype == torch.bfloat16)
and head_size == 128
and block_size == 16
Expand Down Expand Up @@ -229,11 +229,6 @@ def get_attn_backend_cls(
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
if use_mla:
if not use_v1:
raise RuntimeError(
"MLA attention backends require the V1 engine. Set APHRODITE_USE_V1=1 to enable them."
)

from aphrodite.v1.attention.backends.mla.rocm_aiter_mla import is_aiter_mla_enabled

if selected_backend is None:
Expand All @@ -243,14 +238,14 @@ def get_attn_backend_cls(

if selected_backend == _Backend.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend on V1 engine.")
logger.info_once("Using Triton MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.triton_mla.TritonMLABackend"
raise ValueError(
f" The selected backend, {selected_backend.name},does not support block size {block_size}."
)
if selected_backend == _Backend.ROCM_AITER_MLA:
if block_size == 1:
logger.info("Using AITER MLA backend on V1 engine.")
logger.info_once("Using AITER MLA backend.", scope="global")
return "aphrodite.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
raise ValueError(
f" The selected backend, {selected_backend.name},"
Expand All @@ -261,31 +256,27 @@ def get_attn_backend_cls(
f" The selected backend, {selected_backend.name},is not MLA type while requested for MLA backend."
)

if envs.APHRODITE_USE_V1:
if selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttention backend on V1 engine.")
return "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.APHRODITE_ROCM_USE_AITER and envs.APHRODITE_ROCM_USE_AITER_MHA and on_gfx9()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend on V1 engine.")
return "aphrodite.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.APHRODITE_ROCM_USE_AITER and envs.APHRODITE_ROCM_USE_AITER_UNIFIED_ATTENTION
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend on V1 engine.")
return "aphrodite.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
if envs.APHRODITE_V1_USE_PREFILL_DECODE_ATTENTION or selected_backend == _Backend.ROCM_ATTN:
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger.info("Using Rocm Attention backend on V1 engine.")
return "aphrodite.v1.attention.backends.rocm_attn.RocmAttentionBackend"
# default case, using triton unified attention
logger.info("Using Triton Attention backend on V1 engine.")
return "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend"
raise RuntimeError(
"V0 attention backends have been removed. Set APHRODITE_USE_V1=1 to select a supported backend."
)
if selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
return "aphrodite.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
envs.APHRODITE_ROCM_USE_AITER and envs.APHRODITE_ROCM_USE_AITER_MHA and on_gfx9()
) or selected_backend == _Backend.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "aphrodite.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
if (
envs.APHRODITE_ROCM_USE_AITER and envs.APHRODITE_ROCM_USE_AITER_UNIFIED_ATTENTION
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return "aphrodite.v1.attention.backends.rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
if envs.APHRODITE_V1_USE_PREFILL_DECODE_ATTENTION or selected_backend == _Backend.ROCM_ATTN:
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger.info("Using Rocm Attention backend.")
return "aphrodite.v1.attention.backends.rocm_attn.RocmAttentionBackend"
# default case, using triton unified attention
logger.info("Using Triton Attention backend.")
return "aphrodite.v1.attention.backends.triton_attn.TritonAttentionBackend"

@classmethod
def set_device(cls, device: torch.device) -> None:
Expand Down Expand Up @@ -346,7 +337,6 @@ def check_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
parallel_config = aphrodite_config.parallel_config
is_eager_execution = compilation_config == CUDAGraphMode.NONE

use_v1 = envs.APHRODITE_USE_V1
use_aiter_rms_norm = envs.APHRODITE_ROCM_USE_AITER and envs.APHRODITE_ROCM_USE_AITER_RMSNORM

if cache_config and cache_config.block_size is None:
Expand All @@ -355,12 +345,7 @@ def check_and_update_config(cls, aphrodite_config: "AphroditeConfig") -> None:
if parallel_config.worker_cls == "auto":
parallel_config.worker_cls = "aphrodite.v1.worker.gpu_worker.Worker"
# Aiter rms norm perform best when CUDA Graph capture is enabled.
if (
use_v1
and use_aiter_rms_norm
and not is_eager_execution
and "-rms_norm" not in compilation_config.custom_ops
):
if use_aiter_rms_norm and not is_eager_execution and "-rms_norm" not in compilation_config.custom_ops:
compilation_config.custom_ops.append("+rms_norm")

@classmethod
Expand Down
4 changes: 0 additions & 4 deletions aphrodite/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,6 @@ def is_pin_memory_available(cls):
def get_device_communicator_cls(cls) -> str:
return "aphrodite.distributed.device_communicators.tpu_communicator.TpuCommunicator" # noqa

@classmethod
def use_all_gather(cls) -> bool:
return True

@classmethod
def validate_request(
cls,
Expand Down
Loading