Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 2 additions & 7 deletions aphrodite/distributed/device_communicators/all2all.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
from typing import TYPE_CHECKING, Any
from typing import Any

import torch
import torch.distributed as dist
from loguru import logger

from aphrodite.utils import has_deep_ep, has_pplx
from aphrodite.forward_context import get_forward_context
from aphrodite.utils import has_deep_ep, has_pplx

from .base_device_communicator import All2AllManagerBase, Cache

if TYPE_CHECKING:
from aphrodite.modeling.layers.fused_moe.layer import FusedMoE
else:
FusedMoE = None


class NaiveAll2AllManager(All2AllManagerBase):
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,10 @@ def prepare_communication_buffer_for_model(self,

moe_modules = [
module for module in model.modules()
if module.__class__.__name__ == "FusedMoE"
# TODO: Should use isinstance but can't. Maybe search for
# presence of quant_method.init_prepare_finalize?
if (module.__class__.__name__ == "FusedMoE"
or module.__class__.__name__ == "SharedFusedMoE")
Comment on lines +255 to +256
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Checking the module type via `__class__.__name__` is brittle: the comparison silently misses subclasses and breaks whenever a class is renamed or refactored. A more robust approach, as the TODO comment suggests, is to check for the attributes this code actually relies on, namely `quant_method` and its `init_prepare_finalize` method:

            if hasattr(module, "quant_method") and hasattr(
                module.quant_method, "init_prepare_finalize"
            )

]
for module in moe_modules:
module.quant_method.init_prepare_finalize(module)
Expand Down
129 changes: 90 additions & 39 deletions aphrodite/modeling/layers/fused_moe/deepep_ht_prepare_finalize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional
from typing import Callable, Optional, Union

import deep_ep
import torch
Expand All @@ -22,6 +22,7 @@ def __init__(self, buffer: deep_ep.Buffer, num_dispatchers: int,
self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.async_prepare = True
# The dispatch function returns a handle that the combine function
# requires. We store the handle here so it is available to the
# combine function.
Expand Down Expand Up @@ -53,10 +54,16 @@ def _get_combine_config(self) -> Optional[deep_ep.Config]:
return None
return deep_ep.Buffer.get_combine_config(self.dp_size)

def _do_dispatch(self, tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor, num_experts: int):
def _do_dispatch(
self,
tokens: torch.Tensor,
token_scales: Optional[torch.Tensor],
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> Callable:

has_scales = token_scales is not None

Expand Down Expand Up @@ -90,9 +97,36 @@ def _do_dispatch(self, tokens: torch.Tensor,
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=None,
async_finish=False,
async_finish=self.async_prepare,
allocate_on_comm_stream=False)

return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
a1_scale,
quant_config,
)

def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor],
expert_topk_ids: Optional[torch.Tensor],
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: Optional[torch.Tensor],
a1_scale: Optional[torch.Tensor],
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if self.async_prepare:
event.current_stream_wait()

if has_scales:
expert_x, expert_x_scale = token_data
else:
Expand All @@ -109,6 +143,7 @@ def _do_dispatch(self, tokens: torch.Tensor,
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where(
expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
Expand All @@ -120,10 +155,28 @@ def _do_dispatch(self, tokens: torch.Tensor,
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device)

# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16 and quantize afterwards
if not quant_config.is_block_quantized:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)

return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)

def prepare(
# Unconditionally reports async support; the companion `prepare_async`
# method follows immediately after this one.
# NOTE(review): presumably callers consult this flag before choosing
# `prepare_async` over `prepare` -- confirm against the call site.
def supports_async(self) -> bool:
return True

def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
Expand All @@ -134,9 +187,7 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> Callable:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
Expand All @@ -156,37 +207,37 @@ def prepare(
)
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
(expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
a1_post_scale = None
else:
# Dispatch and Quant
# DeepEP kernels only support dispatching block-quantized
# activation scales.
# Dispatch in bfloat16
(expert_x, _, expert_tokens_meta, expert_topk_ids,
expert_topk_weights) = self._do_dispatch(
tokens=a1,
token_scales=None,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts)
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape)
a1q = a1
a1q_scale = None
a1_post_scale = a1_scale

return (expert_x, expert_x_scale, expert_tokens_meta, expert_topk_ids,
expert_topk_weights)
return self._do_dispatch(tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config)

def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Run the prepare step synchronously.

    Every argument is forwarded positionally, in declaration order, to
    ``prepare_async``. The callable it returns is invoked right away, so
    the caller receives the finished result instead of a deferred
    receiver.

    Returns:
        Whatever the receiver produced by ``prepare_async`` returns.
    """
    # Bundle the arguments once so the forwarding order is explicit.
    forwarded = (
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    finish = self.prepare_async(*forwarded)
    return finish()

def finalize(
self,
Expand Down
50 changes: 41 additions & 9 deletions aphrodite/modeling/layers/fused_moe/deepep_ll_prepare_finalize.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union
from typing import Callable, Optional, Union

import deep_ep
import torch
Expand Down Expand Up @@ -73,7 +73,6 @@ def _do_quant(
self,
x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
a1_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
a1_dtype: torch.dtype,
quant_dtype: Union[torch.dtype, str, None],
per_act_token_quant: bool,
Expand Down Expand Up @@ -108,7 +107,10 @@ def _do_quant(

return x, x_scales

def prepare(
# Unconditionally reports async support; the companion `prepare_async`
# method follows immediately after this one.
# NOTE(review): presumably callers consult this flag before choosing
# `prepare_async` over `prepare` -- confirm against the call site.
def supports_async(self) -> bool:
return True

def prepare_async(
self,
a1: torch.Tensor,
a1_scale: Optional[torch.Tensor],
Expand All @@ -119,9 +121,7 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.ReceiverType:

hidden_size = a1.size(1)
assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \
Expand Down Expand Up @@ -153,16 +153,48 @@ def prepare(
num_experts,
use_fp8=self.use_fp8_dispatch,
async_finish=False,
return_recv_hook=False)
return_recv_hook=True)

return lambda: self._receiver(hook, expert_x, expert_num_tokens,
a1_scale, a1.dtype, quant_config)

def _receiver(
self,
hook: Callable,
expert_x: Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
expert_num_tokens: torch.Tensor,
a1_scale,
a1_dtype,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
hook()

expert_x, expert_x_scale = self._do_quant(
expert_x, a1_scale, a2_scale, a1.dtype, quant_config.quant_dtype,
expert_x, a1_scale, a1_dtype, quant_config.quant_dtype,
quant_config.per_act_token_quant, quant_config.block_shape)

expert_tokens_meta = mk.ExpertTokensMetadata(
expert_num_tokens=expert_num_tokens, expert_num_tokens_cpu=None)

return (expert_x, expert_x_scale, expert_tokens_meta, None, None)
return expert_x, expert_x_scale, expert_tokens_meta, None, None

def prepare(
    self,
    a1: torch.Tensor,
    a1_scale: Optional[torch.Tensor],
    a2_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    num_experts: int,
    expert_map: Optional[torch.Tensor],
    apply_router_weight_on_input: bool,
    quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
    """Run the prepare step synchronously.

    Every argument is forwarded positionally, in declaration order, to
    ``prepare_async``. The callable it returns is invoked right away, so
    the caller receives the finished result instead of a deferred
    receiver.

    Returns:
        Whatever the receiver produced by ``prepare_async`` returns.
    """
    # Bundle the arguments once so the forwarding order is explicit.
    forwarded = (
        a1,
        a1_scale,
        a2_scale,
        topk_weights,
        topk_ids,
        num_experts,
        expert_map,
        apply_router_weight_on_input,
        quant_config,
    )
    finish = self.prepare_async(*forwarded)
    return finish()

def finalize(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ def prepare(
apply_router_weight_on_input: bool,
# TODO(bnell): use quant_config + scales instead of ctor args
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.PrepareResultType:

if apply_router_weight_on_input:
topk = topk_ids.size(1)
Expand Down
4 changes: 1 addition & 3 deletions aphrodite/modeling/layers/fused_moe/fused_batched_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,7 @@ def prepare(
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor],
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> mk.PrepareResultType:
assert a1.dim() == 2
assert topk_ids.dim() == 2
assert topk_ids.size(0) == a1.size(0)
Expand Down
Loading
Loading