Skip to content

Commit 18f852d

Browse files
authored
chore: sync to upstream 985961345a13f3e3bb15d29c94b011ba9a6b858b (#1666)
1 parent e958628 commit 18f852d

263 files changed

Lines changed: 19152 additions & 4404 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -948,7 +948,6 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
948948
list(APPEND APHRODITE_MOE_EXT_SRC
949949
"csrc/moe/moe_wna16.cu"
950950
"csrc/moe/grouped_topk_kernels.cu"
951-
"csrc/moe/router_gemm.cu"
952951
"csrc/moe/topk_softplus_sqrt_kernels.cu")
953952
endif()
954953

aphrodite/_aiter_ops.py

Lines changed: 122 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33
import functools
44
from collections.abc import Callable
5+
from contextlib import contextmanager
6+
from typing import Protocol
57

68
import torch
79
from torch._ops import OpOverload
10+
from torch.distributed import ProcessGroup
811

912
import aphrodite.envs as envs
1013
from aphrodite.platforms import current_platform
@@ -39,6 +42,27 @@ def is_aiter_found() -> bool:
3942
IS_AITER_FOUND = is_aiter_found()
4043

4144

class AiterCustomAllreduceProto(Protocol):
    """Structural interface for the AITER custom all-reduce communicator.

    Lists the attributes and methods this module expects on the object
    cached in ``rocm_aiter_ops._CUSTOM_ALL_REDUCE``. The concrete class is
    imported lazily inside ``rocm_aiter_ops.initialize_aiter_allreduce``.
    """

    # NOTE(review): units/meaning of max_size are defined by AITER — confirm.
    max_size: int
    # Read by the fused all-reduce + RMSNorm op when choosing `use_1stage`.
    world_size: int
    fully_connected: bool

    @contextmanager
    def capture(self):
        """Context manager stub; semantics are defined by the AITER class."""
        ...

    def close(self) -> None:
        """Release the communicator (called on teardown)."""
        ...

    def fused_ar_rms(
        self,
        inp: torch.Tensor,
        res_inp: torch.Tensor,
        *,
        w: torch.Tensor,
        eps: float,
        registered: bool = False,
        use_1stage: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused all-reduce + RMSNorm entry point; returns a pair of tensors."""
        ...

    def should_custom_ar(self, inp: torch.Tensor) -> bool:
        """Report whether the custom all-reduce path applies to ``inp``."""
        ...
64+
65+
4266
def is_aiter_found_and_supported() -> bool:
4367
"""Check if AITER library is available and platform supports it.
4468
@@ -731,6 +755,55 @@ def _rocm_aiter_per_tensor_quant_impl(
731755
return per_tensor_quant_hip(x, scale, quant_dtype)
732756

733757

def _rocm_aiter_fused_allreduce_rmsnorm_impl(
    input_: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the AITER fused all-reduce + RMSNorm on ``input_``.

    Fetches the communicator cached on ``rocm_aiter_ops`` and forwards the
    tensors to its ``fused_ar_rms`` method. The ``use_1stage`` flag is derived
    from the input's byte size, its hidden/token dimensions and the
    communicator's ``world_size`` / ``fully_connected`` attributes.

    Args:
        input_: Tensor to reduce. ``shape[0]`` is used as the token count and
            ``shape[-1]`` as the hidden size.
        residual: Forwarded to ``fused_ar_rms`` as ``res_inp``.
        weight: Forwarded to ``fused_ar_rms`` as ``w``.
        epsilon: Forwarded to ``fused_ar_rms`` as ``eps``.

    Returns:
        The first two elements of the ``fused_ar_rms`` result.

    Raises:
        RuntimeError: If the AITER all-reduce communicator has not been
            initialized, or if ``fused_ar_rms`` returns ``None``.
    """
    aiter_ar = rocm_aiter_ops.get_aiter_allreduce()
    # Explicit raise instead of `assert`: assertions are stripped under -O,
    # which would turn this into an opaque AttributeError on None.
    if aiter_ar is None:
        raise RuntimeError("aiter allreduce must be initialized")

    total_bytes = input_.numel() * input_.element_size()
    hidden_dim = input_.shape[-1]
    token_num = input_.shape[0]

    # NOTE(review): the hidden-size allow-list and the token/byte thresholds
    # below presumably mirror what AITER's 1-stage kernel supports — confirm
    # against the AITER version in use before changing them.
    hidden_ok = hidden_dim in (512, 1024, 2048, 4096, 7168)
    token_ok = token_num <= 80

    world_size = aiter_ar.world_size
    fully_connected = aiter_ar.fully_connected

    if world_size == 2:
        size_ok = True
    elif fully_connected and world_size <= 4:
        size_ok = total_bytes < 256 * 1024
    elif fully_connected and world_size <= 8:
        size_ok = total_bytes < 128 * 1024
    else:
        size_ok = False

    use_1stage = hidden_ok and token_ok and size_ok

    result = aiter_ar.fused_ar_rms(
        input_,
        residual,
        w=weight,
        eps=epsilon,
        # `registered` is True exactly while the current stream is being
        # captured into a graph.
        registered=torch.cuda.is_current_stream_capturing(),
        use_1stage=use_1stage,
    )
    if result is None:
        raise RuntimeError("aiter fused_ar_rms returned no result")
    return result[0], result[1]
796+
797+
798+
def _rocm_aiter_fused_allreduce_rmsnorm_fake(
799+
input_: torch.Tensor,
800+
residual: torch.Tensor,
801+
weight: torch.Tensor,
802+
epsilon: float,
803+
) -> tuple[torch.Tensor, torch.Tensor]:
804+
return torch.empty_like(input_), torch.empty_like(residual)
805+
806+
734807
def _rocm_aiter_per_tensor_quant_fake(
735808
x: torch.Tensor,
736809
quant_dtype: torch.dtype,
@@ -747,7 +820,7 @@ def _rocm_aiter_per_token_quant_impl(
747820
assert quant_dtype in [torch.int8, FP8_DTYPE]
748821

749822
out_shape = x.shape
750-
out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device)
823+
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
751824
if scale is None:
752825
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
753826
dynamic_per_token_scaled_quant(
@@ -767,7 +840,7 @@ def _rocm_aiter_per_token_quant_fake(
767840
) -> tuple[torch.Tensor, torch.Tensor]:
768841
out_shape = x.shape
769842
return (
770-
torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
843+
torch.empty(x.shape, dtype=quant_dtype, device=x.device),
771844
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
772845
)
773846

@@ -1157,6 +1230,9 @@ class rocm_aiter_ops:
11571230
# TODO: Consolidate under _LINEAR_ENABLED
11581231
_TRITON_UNQUANT_GEMM = envs.APHRODITE_ROCM_USE_AITER_TRITON_GEMM
11591232

1233+
_ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2
1234+
_CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None
1235+
11601236
@classmethod
11611237
def refresh_env_variables(cls):
11621238
"""
@@ -1324,6 +1400,40 @@ def is_triton_rotary_embed_enabled(cls) -> bool:
13241400
def is_triton_gemm_enabled(cls) -> bool:
13251401
return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM
13261402

@classmethod
@if_aiter_supported
def is_tgemm_enabled(cls) -> bool:
    """Return whether the tgemm path is enabled.

    True only when ``is_linear_enabled()`` holds and the ROCm platform helper
    ``on_gfx950()`` reports a match. The helper is imported inside the method
    rather than at module level.
    """
    from aphrodite.platforms.rocm import on_gfx950

    linear_enabled = cls.is_linear_enabled()
    # `and` short-circuits: on_gfx950() is only queried when linear is on.
    return linear_enabled and on_gfx950()
1409+
@classmethod
def initialize_aiter_allreduce(cls, group: ProcessGroup, device: torch.device) -> None:
    """Create and cache the AITER custom all-reduce communicator.

    This is best-effort: if AITER cannot be imported or the communicator
    cannot be constructed, the cached communicator is reset to ``None`` and
    the failure is logged instead of raised.

    Args:
        group: Process group handed to the AITER ``CustomAllreduce``.
        device: Device handed to the AITER ``CustomAllreduce``.
    """
    import logging

    try:
        from aiter.dist.device_communicators.custom_all_reduce import (
            CustomAllreduce as AiterCustomAllreduce,
        )

        cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
    except Exception:
        # Deliberately broad: both the import and the constructor may fail in
        # environment-specific ways. Keep going without the communicator, but
        # leave a trace so the fallback is not silent.
        cls._CUSTOM_ALL_REDUCE = None
        logging.getLogger(__name__).warning(
            "Failed to initialize AITER custom all-reduce; continuing without it.",
            exc_info=True,
        )
1420+
@classmethod
def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
    """Return the cached AITER all-reduce communicator, or ``None`` if unset."""
    communicator = cls._CUSTOM_ALL_REDUCE
    return communicator
1424+
@classmethod
def destroy_aiter_allreduce(cls) -> None:
    """Close and forget the cached AITER all-reduce communicator.

    Does nothing when no communicator is cached. The cached reference is
    cleared *before* ``close()`` is called, so an exception raised by
    ``close()`` still propagates to the caller but cannot leave a half-closed
    communicator behind for later calls to pick up.
    """
    communicator = cls._CUSTOM_ALL_REDUCE
    if communicator is None:
        return
    cls._CUSTOM_ALL_REDUCE = None
    communicator.close()
1430+
1431+
@classmethod
1432+
def get_aiter_allreduce_max_size(cls) -> int | None:
1433+
# effective max input size (based on upstream aiter version: v0.1.10.post3)
1434+
# https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273
1435+
return int(cls._ALL_REDUCE_MAX_SIZE / 2)
1436+
13271437
@staticmethod
13281438
@if_aiter_supported
13291439
def register_ops_once() -> None:
@@ -1514,6 +1624,12 @@ def register_ops_once() -> None:
15141624
fake_impl=_triton_rotary_embedding_fake,
15151625
)
15161626

1627+
direct_register_custom_op(
1628+
op_name="rocm_aiter_fused_allreduce_rmsnorm",
1629+
op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl,
1630+
fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake,
1631+
)
1632+
15171633
direct_register_custom_op(
15181634
op_name="fused_mla_dual_rms_norm",
15191635
op_func=_fused_mla_dual_rms_norm_impl,
@@ -1567,6 +1683,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload:
15671683
def get_triton_rotary_embedding_op() -> OpOverload:
15681684
return torch.ops.aphrodite.rocm_aiter_triton_rotary_embedding.default
15691685

@staticmethod
def get_fused_allreduce_rmsnorm_op() -> OpOverload:
    """Return the ``default`` overload of the fused all-reduce + RMSNorm op.

    The op is looked up as ``torch.ops.aphrodite.rocm_aiter_fused_allreduce_rmsnorm``,
    so it must already have been registered when this is called.
    """
    op_packet = torch.ops.aphrodite.rocm_aiter_fused_allreduce_rmsnorm
    return op_packet.default
1689+
@staticmethod
def get_fused_mla_dual_rms_norm_op() -> OpOverload:
    """Return the ``default`` overload of the fused MLA dual RMSNorm op.

    The op is looked up as ``torch.ops.aphrodite.fused_mla_dual_rms_norm``,
    so it must already have been registered when this is called.
    """
    op_packet = torch.ops.aphrodite.fused_mla_dual_rms_norm
    return op_packet.default

aphrodite/_custom_ops.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,21 +2632,6 @@ def moe_wna16_gemm(
26322632
)
26332633

26342634

2635-
def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
2636-
"""bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K)."""
2637-
return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight)
2638-
2639-
2640-
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"):
2641-
2642-
@register_fake("_moe_C::router_gemm_bf16_fp32")
2643-
def router_gemm_bf16_fp32_fake(
2644-
input: torch.Tensor,
2645-
weight: torch.Tensor,
2646-
) -> torch.Tensor:
2647-
return torch.empty(input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device)
2648-
2649-
26502635
def dsv3_router_gemm(
26512636
hidden_states: torch.Tensor,
26522637
router_weight: torch.Tensor,
@@ -3552,6 +3537,9 @@ def cpu_attn_reshape_and_cache(
35523537
value_cache: torch.Tensor,
35533538
slot_mapping: torch.Tensor,
35543539
isa: str,
3540+
k_scale: float = 1.0,
3541+
v_scale: float = 1.0,
3542+
kv_cache_dtype: str = "auto",
35553543
) -> None:
35563544
torch.ops._C.cpu_attn_reshape_and_cache(
35573545
key,
@@ -3560,6 +3548,9 @@ def cpu_attn_reshape_and_cache(
35603548
value_cache,
35613549
slot_mapping,
35623550
isa,
3551+
k_scale,
3552+
v_scale,
3553+
kv_cache_dtype,
35633554
)
35643555

35653556

@@ -3578,6 +3569,9 @@ def cpu_attention_with_kv_cache(
35783569
softcap: float,
35793570
scheduler_metadata: torch.Tensor,
35803571
s_aux: torch.Tensor | None,
3572+
k_scale: float = 1.0,
3573+
v_scale: float = 1.0,
3574+
kv_cache_dtype: str = "auto",
35813575
) -> None:
35823576
torch.ops._C.cpu_attention_with_kv_cache(
35833577
query,
@@ -3595,6 +3589,9 @@ def cpu_attention_with_kv_cache(
35953589
softcap,
35963590
scheduler_metadata,
35973591
s_aux,
3592+
k_scale,
3593+
v_scale,
3594+
kv_cache_dtype,
35983595
)
35993596

36003597

aphrodite/compilation/backends.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,7 @@ def compile(
265265
compilation_counter.num_backend_compilations += 1
266266

267267
compiled_graph = None
268+
handle = None
268269

269270
# try to load from the cache
270271
compiled_graph = self.load(graph, example_inputs, graph_index, compile_range)
@@ -342,7 +343,7 @@ def autograd_cache_key(*args, **kwargs):
342343
)
343344
except StopCompiling:
344345
assert cache_key is not None
345-
return self.loaded_artifacts[cache_key]
346+
compiled_graph = self.loaded_artifacts[cache_key]
346347
if cache_key is not None and compiled_graph is not None:
347348
self.loaded_artifacts[cache_key] = compiled_graph
348349

aphrodite/compilation/cuda_graph.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
268268
# across layers will make the cudagraph capture very slow.
269269
# therefore, we only run gc for the first graph,
270270
# and disable gc for the rest of the graphs.
271-
stack.enter_context(patch("gc.collect", lambda: None))
272-
stack.enter_context(patch("torch.accelerator.empty_cache", lambda: None))
271+
stack.enter_context(patch("gc.collect", lambda *args, **kwargs: None))
272+
stack.enter_context(
273+
patch(
274+
"torch.accelerator.empty_cache",
275+
lambda *args, **kwargs: None,
276+
)
277+
)
273278

274279
if self.graph_pool is not None:
275280
set_graph_pool_id(self.graph_pool)

0 commit comments

Comments
 (0)