Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -948,7 +948,6 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
list(APPEND APHRODITE_MOE_EXT_SRC
"csrc/moe/moe_wna16.cu"
"csrc/moe/grouped_topk_kernels.cu"
"csrc/moe/router_gemm.cu"
"csrc/moe/topk_softplus_sqrt_kernels.cu")
endif()

Expand Down
124 changes: 122 additions & 2 deletions aphrodite/_aiter_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import functools
from collections.abc import Callable
from contextlib import contextmanager
from typing import Protocol

import torch
from torch._ops import OpOverload
from torch.distributed import ProcessGroup

import aphrodite.envs as envs
from aphrodite.platforms import current_platform
Expand Down Expand Up @@ -39,6 +42,27 @@ def is_aiter_found() -> bool:
IS_AITER_FOUND = is_aiter_found()


class AiterCustomAllreduceProto(Protocol):
    """Structural type for the AITER custom allreduce communicator.

    Lists only the attributes and methods this module relies on, so the
    real AITER class does not need to be importable for type checking.
    """

    # Attributes read from the communicator object.
    max_size: int
    world_size: int
    fully_connected: bool

    @contextmanager
    def capture(self):
        """Context manager exposed by the communicator (signature only)."""
        ...

    def close(self) -> None:
        """Release the communicator (signature only)."""
        ...

    def fused_ar_rms(
        self,
        inp: torch.Tensor,
        res_inp: torch.Tensor,
        *,
        w: torch.Tensor,
        eps: float,
        registered: bool = False,
        use_1stage: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused allreduce + RMS norm entry point (signature only).

        ``w``, ``eps``, ``registered`` and ``use_1stage`` are keyword-only.
        Returns a pair of tensors.
        """
        ...

    def should_custom_ar(self, inp: torch.Tensor) -> bool:
        """Report whether the custom allreduce applies to ``inp`` (signature only)."""
        ...


def is_aiter_found_and_supported() -> bool:
"""Check if AITER library is available and platform supports it.

Expand Down Expand Up @@ -731,6 +755,55 @@ def _rocm_aiter_per_tensor_quant_impl(
return per_tensor_quant_hip(x, scale, quant_dtype)


def _rocm_aiter_fused_allreduce_rmsnorm_impl(
    input_: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the AITER fused allreduce + RMS norm on ``input_``/``residual``.

    Looks up the communicator held by ``rocm_aiter_ops``, decides the
    ``use_1stage`` flag from the input size and the communicator topology,
    and forwards everything to ``fused_ar_rms``.

    Returns:
        The first two elements of the ``fused_ar_rms`` result.

    Raises:
        AssertionError: if the communicator has not been initialized or
            ``fused_ar_rms`` returns ``None``.
    """
    communicator = rocm_aiter_ops.get_aiter_allreduce()
    assert communicator is not None, "aiter allreduce must be initialized"

    # Facts about the input that feed the use_1stage decision.
    payload_bytes = input_.numel() * input_.element_size()
    last_dim = input_.shape[-1]
    first_dim = input_.shape[0]
    num_ranks = communicator.world_size
    fully_connected = communicator.fully_connected

    # Byte budget: two ranks always qualify; otherwise the communicator
    # must be fully connected, with a tighter limit above four ranks.
    if num_ranks == 2:
        within_budget = True
    elif not fully_connected or num_ranks > 8:
        within_budget = False
    elif num_ranks <= 4:
        within_budget = payload_bytes < 256 * 1024
    else:
        within_budget = payload_bytes < 128 * 1024

    # All three gates (last-dim whitelist, first-dim cap, byte budget)
    # must pass for the flag to be set.
    dim_supported = last_dim in (512, 1024, 2048, 4096, 7168)
    few_enough_rows = first_dim <= 80
    single_stage = dim_supported and few_enough_rows and within_budget

    fused_out = communicator.fused_ar_rms(
        input_,
        residual,
        w=weight,
        eps=epsilon,
        # Forward whether the current CUDA stream is being captured.
        registered=torch.cuda.is_current_stream_capturing(),
        use_1stage=single_stage,
    )
    assert fused_out is not None
    return fused_out[0], fused_out[1]


def _rocm_aiter_fused_allreduce_rmsnorm_fake(
input_: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(input_), torch.empty_like(residual)


def _rocm_aiter_per_tensor_quant_fake(
x: torch.Tensor,
quant_dtype: torch.dtype,
Expand All @@ -747,7 +820,7 @@ def _rocm_aiter_per_token_quant_impl(
assert quant_dtype in [torch.int8, FP8_DTYPE]

out_shape = x.shape
out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
Expand All @@ -767,7 +840,7 @@ def _rocm_aiter_per_token_quant_fake(
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
torch.empty(x.shape, dtype=quant_dtype, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)

Expand Down Expand Up @@ -1157,6 +1230,9 @@ class rocm_aiter_ops:
# TODO: Consolidate under _LINEAR_ENABLED
_TRITON_UNQUANT_GEMM = envs.APHRODITE_ROCM_USE_AITER_TRITON_GEMM

_ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2
_CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None

@classmethod
def refresh_env_variables(cls):
"""
Expand Down Expand Up @@ -1324,6 +1400,40 @@ def is_triton_rotary_embed_enabled(cls) -> bool:
def is_triton_gemm_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM

@classmethod
@if_aiter_supported
def is_tgemm_enabled(cls) -> bool:
    """Return ``is_linear_enabled()`` combined with the ``on_gfx950()`` check.

    ``on_gfx950`` is imported lazily from the ROCm platform module and is
    consulted only when ``is_linear_enabled()`` is truthy.
    """
    from aphrodite.platforms.rocm import on_gfx950

    linear_enabled = cls.is_linear_enabled()
    return linear_enabled and on_gfx950()

@classmethod
def initialize_aiter_allreduce(cls, group: ProcessGroup, device: torch.device) -> None:
    """Create the AITER custom allreduce communicator and cache it on the class.

    This is best-effort: if the AITER module cannot be imported or the
    communicator cannot be constructed, the cached handle is set to
    ``None`` and the failure is logged instead of being dropped silently.

    Args:
        group: Process group passed to the AITER communicator.
        device: Device passed to the AITER communicator.
    """
    import logging

    try:
        from aiter.dist.device_communicators.custom_all_reduce import (
            CustomAllreduce as AiterCustomAllreduce,
        )

        cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
    except Exception:
        # Deliberately broad: any import or construction failure from the
        # third-party library disables the feature rather than crashing.
        # Log with the traceback so the reason is still discoverable.
        logging.getLogger(__name__).warning(
            "Failed to initialize AITER custom allreduce; continuing without it.",
            exc_info=True,
        )
        cls._CUSTOM_ALL_REDUCE = None

@classmethod
def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
    """Return the cached AITER allreduce communicator, or ``None`` if unset."""
    cached = cls._CUSTOM_ALL_REDUCE
    return cached

@classmethod
def destroy_aiter_allreduce(cls) -> None:
    """Close and forget the cached AITER allreduce communicator.

    Does nothing when no communicator is cached. The cached handle is
    cleared even if ``close()`` raises, so a failed teardown never leaves
    a stale communicator behind; the exception still propagates.
    """
    communicator = cls._CUSTOM_ALL_REDUCE
    if communicator is None:
        return
    try:
        communicator.close()
    finally:
        cls._CUSTOM_ALL_REDUCE = None

@classmethod
def get_aiter_allreduce_max_size(cls) -> int | None:
# effective max input size (based on upstream aiter version: v0.1.10.post3)
# https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273
return int(cls._ALL_REDUCE_MAX_SIZE / 2)

@staticmethod
@if_aiter_supported
def register_ops_once() -> None:
Expand Down Expand Up @@ -1514,6 +1624,12 @@ def register_ops_once() -> None:
fake_impl=_triton_rotary_embedding_fake,
)

direct_register_custom_op(
op_name="rocm_aiter_fused_allreduce_rmsnorm",
op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl,
fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake,
)

direct_register_custom_op(
op_name="fused_mla_dual_rms_norm",
op_func=_fused_mla_dual_rms_norm_impl,
Expand Down Expand Up @@ -1567,6 +1683,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload:
def get_triton_rotary_embedding_op() -> OpOverload:
return torch.ops.aphrodite.rocm_aiter_triton_rotary_embedding.default

@staticmethod
def get_fused_allreduce_rmsnorm_op() -> OpOverload:
    """Return the default overload of ``rocm_aiter_fused_allreduce_rmsnorm``."""
    op_packet = torch.ops.aphrodite.rocm_aiter_fused_allreduce_rmsnorm
    return op_packet.default

@staticmethod
def get_fused_mla_dual_rms_norm_op() -> OpOverload:
    """Return the default overload of the ``fused_mla_dual_rms_norm`` op."""
    op_packet = torch.ops.aphrodite.fused_mla_dual_rms_norm
    return op_packet.default
Expand Down
27 changes: 12 additions & 15 deletions aphrodite/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2632,21 +2632,6 @@ def moe_wna16_gemm(
)


def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Compute a bf16 x bf16 -> fp32 GEMM via cuBLAS.

    Thin wrapper over the ``_moe_C.router_gemm_bf16_fp32`` extension op.
    ``weight`` has shape (N, K).
    """
    gemm_op = torch.ops._moe_C.router_gemm_bf16_fp32
    return gemm_op(input, weight)


# Register the fake kernel only when the compiled extension actually
# provides the op.
if (
    hasattr(torch.ops, "_moe_C")
    and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32")
):

    @register_fake("_moe_C::router_gemm_bf16_fp32")
    def router_gemm_bf16_fp32_fake(
        input: torch.Tensor,
        weight: torch.Tensor,
    ) -> torch.Tensor:
        """Fake kernel: allocate an fp32 output without computing anything.

        The result has shape ``(input.shape[0], weight.shape[0])`` and
        lives on ``input``'s device.
        """
        rows = input.shape[0]
        cols = weight.shape[0]
        return torch.empty((rows, cols), dtype=torch.float32, device=input.device)


def dsv3_router_gemm(
hidden_states: torch.Tensor,
router_weight: torch.Tensor,
Expand Down Expand Up @@ -3552,6 +3537,9 @@ def cpu_attn_reshape_and_cache(
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
isa: str,
k_scale: float = 1.0,
v_scale: float = 1.0,
kv_cache_dtype: str = "auto",
) -> None:
torch.ops._C.cpu_attn_reshape_and_cache(
key,
Expand All @@ -3560,6 +3548,9 @@ def cpu_attn_reshape_and_cache(
value_cache,
slot_mapping,
isa,
k_scale,
v_scale,
kv_cache_dtype,
)


Expand All @@ -3578,6 +3569,9 @@ def cpu_attention_with_kv_cache(
softcap: float,
scheduler_metadata: torch.Tensor,
s_aux: torch.Tensor | None,
k_scale: float = 1.0,
v_scale: float = 1.0,
kv_cache_dtype: str = "auto",
) -> None:
torch.ops._C.cpu_attention_with_kv_cache(
query,
Expand All @@ -3595,6 +3589,9 @@ def cpu_attention_with_kv_cache(
softcap,
scheduler_metadata,
s_aux,
k_scale,
v_scale,
kv_cache_dtype,
)


Expand Down
3 changes: 2 additions & 1 deletion aphrodite/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def compile(
compilation_counter.num_backend_compilations += 1

compiled_graph = None
handle = None

# try to load from the cache
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
Expand Down Expand Up @@ -342,7 +343,7 @@ def autograd_cache_key(*args, **kwargs):
)
except StopCompiling:
assert cache_key is not None
return self.loaded_artifacts[cache_key]
compiled_graph = self.loaded_artifacts[cache_key]
if cache_key is not None and compiled_graph is not None:
self.loaded_artifacts[cache_key] = compiled_graph

Expand Down
9 changes: 7 additions & 2 deletions aphrodite/compilation/cuda_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
# across layers will make the cudagraph capture very slow.
# therefore, we only run gc for the first graph,
# and disable gc for the rest of the graphs.
stack.enter_context(patch("gc.collect", lambda: None))
stack.enter_context(patch("torch.accelerator.empty_cache", lambda: None))
stack.enter_context(patch("gc.collect", lambda *args, **kwargs: None))
stack.enter_context(
patch(
"torch.accelerator.empty_cache",
lambda *args, **kwargs: None,
)
)

if self.graph_pool is not None:
set_graph_pool_id(self.graph_pool)
Expand Down
Loading
Loading