22# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
33import functools
44from collections .abc import Callable
5+ from contextlib import contextmanager
6+ from typing import Protocol
57
68import torch
79from torch ._ops import OpOverload
10+ from torch .distributed import ProcessGroup
811
912import aphrodite .envs as envs
1013from aphrodite .platforms import current_platform
@@ -39,6 +42,27 @@ def is_aiter_found() -> bool:
3942IS_AITER_FOUND = is_aiter_found ()
4043
4144
class AiterCustomAllreduceProto(Protocol):
    """Structural type for the AITER custom all-reduce communicator object.

    Lists the attributes and methods that are accessed through this type;
    the method bodies are stubs and carry no implementation.
    """

    # Attributes the communicator is expected to expose.
    max_size: int
    world_size: int
    fully_connected: bool

    def should_custom_ar(self, inp: torch.Tensor) -> bool:
        """Report whether ``inp`` should go through the custom all-reduce."""
        ...

    def fused_ar_rms(
        self,
        inp: torch.Tensor,
        res_inp: torch.Tensor,
        *,
        w: torch.Tensor,
        eps: float,
        registered: bool = False,
        use_1stage: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Fused all-reduce + RMSNorm entry point returning two tensors.

        ``w``, ``eps``, ``registered`` and ``use_1stage`` are keyword-only.
        """
        ...

    @contextmanager
    def capture(self):
        """Context-manager hook of the communicator (stub)."""
        ...

    def close(self) -> None:
        """Release the communicator (stub)."""
        ...
64+
65+
4266def is_aiter_found_and_supported () -> bool :
4367 """Check if AITER library is available and platform supports it.
4468
@@ -731,6 +755,55 @@ def _rocm_aiter_per_tensor_quant_impl(
731755 return per_tensor_quant_hip (x , scale , quant_dtype )
732756
733757
def _rocm_aiter_fused_allreduce_rmsnorm_impl(
    input_: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    epsilon: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run the AITER fused all-reduce + RMSNorm through the shared communicator.

    The communicator is fetched from ``rocm_aiter_ops.get_aiter_allreduce()``
    and must already be initialized. ``use_1stage`` is requested only when
    the last dimension is one of the listed sizes, the first dimension is at
    most 80, and the payload fits the byte budget for the world size.

    Returns:
        The first two elements of the communicator's ``fused_ar_rms`` result.
    """
    communicator = rocm_aiter_ops.get_aiter_allreduce()
    assert communicator is not None, "aiter allreduce must be initialized"

    payload_bytes = input_.numel() * input_.element_size()
    hidden_ok = input_.shape[-1] in (512, 1024, 2048, 4096, 7168)
    token_ok = input_.shape[0] <= 80
    num_ranks = communicator.world_size
    fully_connected = communicator.fully_connected

    # Byte budget: two ranks always qualify; otherwise the communicator must
    # be fully connected, with 256 KiB allowed up to 4 ranks and 128 KiB up
    # to 8 ranks. Anything else never qualifies.
    if num_ranks == 2:
        size_ok = True
    elif not fully_connected or num_ranks > 8:
        size_ok = False
    else:
        kib_limit = 256 if num_ranks <= 4 else 128
        size_ok = payload_bytes < kib_limit * 1024

    use_1stage = hidden_ok and token_ok and size_ok

    outputs = communicator.fused_ar_rms(
        input_,
        residual,
        w=weight,
        eps=epsilon,
        # Forwarded as-is: True while the current stream is being captured.
        registered=torch.cuda.is_current_stream_capturing(),
        use_1stage=use_1stage,
    )
    assert outputs is not None
    return outputs[0], outputs[1]
796+
797+
798+ def _rocm_aiter_fused_allreduce_rmsnorm_fake (
799+ input_ : torch .Tensor ,
800+ residual : torch .Tensor ,
801+ weight : torch .Tensor ,
802+ epsilon : float ,
803+ ) -> tuple [torch .Tensor , torch .Tensor ]:
804+ return torch .empty_like (input_ ), torch .empty_like (residual )
805+
806+
734807def _rocm_aiter_per_tensor_quant_fake (
735808 x : torch .Tensor ,
736809 quant_dtype : torch .dtype ,
@@ -747,7 +820,7 @@ def _rocm_aiter_per_token_quant_impl(
747820 assert quant_dtype in [torch .int8 , FP8_DTYPE ]
748821
749822 out_shape = x .shape
750- out = torch .empty (x .shape , dtype = FP8_DTYPE , device = x .device )
823+ out = torch .empty (x .shape , dtype = quant_dtype , device = x .device )
751824 if scale is None :
752825 scale = torch .empty ((* out_shape [:- 1 ], 1 ), dtype = torch .float32 , device = x .device )
753826 dynamic_per_token_scaled_quant (
@@ -767,7 +840,7 @@ def _rocm_aiter_per_token_quant_fake(
767840) -> tuple [torch .Tensor , torch .Tensor ]:
768841 out_shape = x .shape
769842 return (
770- torch .empty (x .shape , dtype = FP8_DTYPE , device = x .device ),
843+ torch .empty (x .shape , dtype = quant_dtype , device = x .device ),
771844 torch .empty ((* out_shape [:- 1 ], 1 ), dtype = torch .float32 , device = x .device ),
772845 )
773846
@@ -1157,6 +1230,9 @@ class rocm_aiter_ops:
11571230 # TODO: Consolidate under _LINEAR_ENABLED
11581231 _TRITON_UNQUANT_GEMM = envs .APHRODITE_ROCM_USE_AITER_TRITON_GEMM
11591232
1233+ _ALL_REDUCE_MAX_SIZE : int = 8192 * 1024 * 8 * 2
1234+ _CUSTOM_ALL_REDUCE : AiterCustomAllreduceProto | None = None
1235+
11601236 @classmethod
11611237 def refresh_env_variables (cls ):
11621238 """
@@ -1324,6 +1400,40 @@ def is_triton_rotary_embed_enabled(cls) -> bool:
13241400 def is_triton_gemm_enabled (cls ) -> bool :
13251401 return cls ._AITER_ENABLED and cls ._TRITON_UNQUANT_GEMM
13261402
@classmethod
@if_aiter_supported
def is_tgemm_enabled(cls) -> bool:
    """Tell whether the tgemm path may be used.

    Requires ``cls.is_linear_enabled()`` to be truthy and ``on_gfx950()``
    to report true; the platform check is skipped when the first fails.
    """
    # Imported lazily, as in the rest of this method's original form.
    from aphrodite.platforms.rocm import on_gfx950

    tgemm_usable = cls.is_linear_enabled() and on_gfx950()
    return tgemm_usable
1409+
@classmethod
def initialize_aiter_allreduce(cls, group: ProcessGroup, device: torch.device) -> None:
    """Create the AITER custom all-reduce communicator, best effort.

    On success the communicator is stored in ``cls._CUSTOM_ALL_REDUCE``.
    If the AITER import or the constructor fails for any reason, the slot
    is reset to ``None`` and the failure is logged instead of being
    silently dropped; no exception propagates to the caller.

    Args:
        group: Process group handed to the AITER communicator.
        device: Device handed to the AITER communicator.
    """
    try:
        from aiter.dist.device_communicators.custom_all_reduce import (
            CustomAllreduce as AiterCustomAllreduce,
        )

        cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device)
    except Exception:
        # Still best effort, but keep the traceback so a missing or broken
        # AITER install can be diagnosed.
        import logging

        logging.getLogger(__name__).warning(
            "Failed to initialize AITER custom all-reduce; "
            "the fused all-reduce path will be unavailable.",
            exc_info=True,
        )
        cls._CUSTOM_ALL_REDUCE = None
1420+
@classmethod
def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None:
    """Return the stored AITER all-reduce communicator, or ``None`` if unset."""
    communicator = cls._CUSTOM_ALL_REDUCE
    return communicator
1424+
@classmethod
def destroy_aiter_allreduce(cls) -> None:
    """Close and forget the stored AITER all-reduce communicator.

    Does nothing when no communicator is stored. The stored reference is
    cleared even if ``close()`` raises, so a half-closed communicator is
    never handed out again; the exception itself still propagates.
    """
    communicator = cls._CUSTOM_ALL_REDUCE
    if communicator is None:
        return
    try:
        communicator.close()
    finally:
        cls._CUSTOM_ALL_REDUCE = None
1430+
1431+ @classmethod
1432+ def get_aiter_allreduce_max_size (cls ) -> int | None :
1433+ # effective max input size (based on upstream aiter version: v0.1.10.post3)
1434+ # https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273
1435+ return int (cls ._ALL_REDUCE_MAX_SIZE / 2 )
1436+
13271437 @staticmethod
13281438 @if_aiter_supported
13291439 def register_ops_once () -> None :
@@ -1514,6 +1624,12 @@ def register_ops_once() -> None:
15141624 fake_impl = _triton_rotary_embedding_fake ,
15151625 )
15161626
1627+ direct_register_custom_op (
1628+ op_name = "rocm_aiter_fused_allreduce_rmsnorm" ,
1629+ op_func = _rocm_aiter_fused_allreduce_rmsnorm_impl ,
1630+ fake_impl = _rocm_aiter_fused_allreduce_rmsnorm_fake ,
1631+ )
1632+
15171633 direct_register_custom_op (
15181634 op_name = "fused_mla_dual_rms_norm" ,
15191635 op_func = _fused_mla_dual_rms_norm_impl ,
@@ -1567,6 +1683,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload:
15671683 def get_triton_rotary_embedding_op () -> OpOverload :
15681684 return torch .ops .aphrodite .rocm_aiter_triton_rotary_embedding .default
15691685
@staticmethod
def get_fused_allreduce_rmsnorm_op() -> OpOverload:
    """Return the default overload of ``rocm_aiter_fused_allreduce_rmsnorm``.

    The op is looked up in the ``aphrodite`` namespace of ``torch.ops``.
    """
    op_packet = torch.ops.aphrodite.rocm_aiter_fused_allreduce_rmsnorm
    return op_packet.default
1689+
@staticmethod
def get_fused_mla_dual_rms_norm_op() -> OpOverload:
    """Return the default overload of ``fused_mla_dual_rms_norm``.

    The op is looked up in the ``aphrodite`` namespace of ``torch.ops``.
    """
    op_packet = torch.ops.aphrodite.fused_mla_dual_rms_norm
    return op_packet.default
0 commit comments