megatron/core/models/gpt/fine_grained_callables.py (4 changes: 2 additions & 2 deletions)
@@ -61,7 +61,7 @@ def should_free_input(name, is_moe, config, num_local_experts):
        return False
    enable_deepep = (
        config.moe_token_dispatcher_type == "flex"
-        and config.moe_flex_dispatcher_backend == "deepep"
+        and config.moe_flex_dispatcher_backend in ("deepep", "deepepv2")
    )
    enable_hybridep = (
        config.moe_token_dispatcher_type == "flex"
@@ -423,7 +423,7 @@ def build_transformer_layer_callables(layer: TransformerLayer):
    is_moe = isinstance(layer.mlp, MoELayer)
    enable_deepep = (
        layer.config.moe_token_dispatcher_type == "flex"
-        and layer.config.moe_flex_dispatcher_backend == "deepep"
+        and layer.config.moe_flex_dispatcher_backend in ("deepep", "deepepv2")
    )
    enable_hybridep = (
        layer.config.moe_token_dispatcher_type == "flex"
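Both call sites now accept "deepepv2" alongside "deepep" as a DeepEP-backed flex dispatcher backend. A minimal config sketch (illustrative only, not part of the diff; the attribute names are the ones tested above):

    # Selecting the new backend; either value enables the DeepEP code path.
    config.moe_token_dispatcher_type = "flex"
    config.moe_flex_dispatcher_backend = "deepepv2"  # or "deepep"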
megatron/core/transformer/moe/fused_a2a.py (206 changes: 205 additions & 1 deletion)
@@ -8,16 +8,32 @@
from megatron.core.utils import internal_api

try:
-    from deep_ep import Buffer
    from deep_ep.utils import EventHandle, EventOverlap
except ImportError:
    try:
        from deep_ep import EventHandle, EventOverlap
    except ImportError:
        EventHandle = None
        EventOverlap = None

try:
    from deep_ep import Buffer

    HAVE_DEEP_EP = True
except ImportError:
    HAVE_DEEP_EP = False

try:
    from deep_ep import ElasticBuffer

    HAVE_DEEP_EP_V2 = True
except ImportError:
    HAVE_DEEP_EP_V2 = False

import torch

_buffer = None
_elastic_buffer = None


def get_hidden_bytes(x: torch.Tensor) -> int:
@@ -68,6 +84,39 @@ def get_buffer(group: torch.distributed.ProcessGroup, hidden_bytes: int):
    return _buffer


def get_elastic_buffer(
    group: torch.distributed.ProcessGroup, num_max_tokens_per_rank: int, hidden: int, num_topk: int
):
    """Get or create a DeepEP v2 elastic buffer for all-to-all communication."""
    global _elastic_buffer

    num_bytes = ElasticBuffer.get_buffer_size_hint(
        group, num_max_tokens_per_rank=num_max_tokens_per_rank, hidden=hidden, num_topk=num_topk
    )

    if (
        _elastic_buffer is None
        or _elastic_buffer.group != group
        or _elastic_buffer.num_bytes < num_bytes
        or _elastic_buffer.num_max_tokens_per_rank < num_max_tokens_per_rank
    ):
        _elastic_buffer = ElasticBuffer(
            group,
            num_bytes=num_bytes,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            hidden=hidden,
            num_topk=num_topk,
        )
    return _elastic_buffer

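# Illustrative usage sketch, not part of the diff: `ep_group` and the sizes are
# assumed values. The module-level `_elastic_buffer` cache above is reused across
# calls and only reallocated when the group changes or a capacity bound grows.
#
#     buffer = get_elastic_buffer(
#         ep_group, num_max_tokens_per_rank=4096, hidden=4096, num_topk=8
#     )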

def _capture_elastic_previous_event(buffer, async_finish: bool, allocate_on_comm_stream: bool):
    """Capture the current stream for DeepEP v2 async communication."""
    if async_finish and allocate_on_comm_stream:
        return buffer.capture()
    return None


class FusedDispatch(torch.autograd.Function):
"""Fused dispatch operation for MoE routing combining computation and communication."""

@@ -267,6 +316,161 @@ def set_deepep_num_sms(num_sms):
    set_deepep_num_sms = None


class DeepepV2Dispatch(torch.autograd.Function):
    """DeepEP v2 elastic dispatch with autograd support."""

    @staticmethod
    def forward(
        ctx,
        buffer,
        x,
        token_indices,
        token_probs,
        num_experts,
        num_max_tokens_per_rank,
        expert_alignment,
        num_sms,
        async_finish=False,
        allocate_on_comm_stream=False,
    ):
        """Forward pass of DeepEP v2 elastic dispatch."""
        previous_event = _capture_elastic_previous_event(
            buffer, async_finish, allocate_on_comm_stream
        )
        recv_x, recv_token_indices, recv_token_probs, handle, event = buffer.dispatch(
            x,
            topk_idx=token_indices,
            topk_weights=token_probs,
            num_experts=num_experts,
            num_max_tokens_per_rank=num_max_tokens_per_rank,
            expert_alignment=expert_alignment,
            num_sms=num_sms,
            previous_event=previous_event,
            async_with_compute_stream=async_finish,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )

        if async_finish:
            event.current_stream_wait()

        ctx.buffer = buffer
        ctx.handle = handle
        ctx.num_sms = num_sms
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream
        tokens_per_expert = torch.tensor(handle.num_recv_tokens_per_expert_list)

        return (recv_x, recv_token_indices, recv_token_probs, tokens_per_expert, handle)

    @staticmethod
    def backward(
        ctx, grad_output, grad_token_indices, grad_token_probs, grad_tokens_per_expert, grad_handle
    ):
        """Backward pass of DeepEP v2 elastic dispatch."""
        previous_event = _capture_elastic_previous_event(
            ctx.buffer, ctx.async_finish, ctx.allocate_on_comm_stream
        )
        grad_x, grad_token_probs, event = ctx.buffer.combine(
            grad_output.contiguous(),
            handle=ctx.handle,
            topk_weights=grad_token_probs.float(),
            num_sms=ctx.num_sms,
            previous_event=previous_event,
            async_with_compute_stream=ctx.async_finish,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )
        if ctx.async_finish:
            event.current_stream_wait()
        return None, grad_x, None, grad_token_probs, None, None, None, None, None, None


class DeepepV2Combine(torch.autograd.Function):
    """DeepEP v2 elastic combine with autograd support."""

    @staticmethod
    def forward(ctx, buffer, x, handle, num_sms, async_finish=False, allocate_on_comm_stream=False):
        """Forward pass of DeepEP v2 elastic combine."""
        previous_event = _capture_elastic_previous_event(
            buffer, async_finish, allocate_on_comm_stream
        )
        combined_x, combined_token_probs, event = buffer.combine(
            x,
            handle=handle,
            num_sms=num_sms,
            previous_event=previous_event,
            async_with_compute_stream=async_finish,
            allocate_on_comm_stream=allocate_on_comm_stream,
        )
        if async_finish:
            event.current_stream_wait()

        ctx.buffer = buffer
        ctx.handle = handle
        ctx.num_sms = num_sms
        ctx.async_finish = async_finish
        ctx.allocate_on_comm_stream = allocate_on_comm_stream
        return combined_x, combined_token_probs

    @staticmethod
    def backward(ctx, grad_output, grad_combined_token_probs):
        """Backward pass of DeepEP v2 elastic combine."""
        previous_event = _capture_elastic_previous_event(
            ctx.buffer, ctx.async_finish, ctx.allocate_on_comm_stream
        )
        grad_x, _, _, _, event = ctx.buffer.dispatch(
            grad_output.contiguous(),
            handle=ctx.handle,
            num_sms=ctx.num_sms,
            previous_event=previous_event,
            async_with_compute_stream=ctx.async_finish,
            allocate_on_comm_stream=ctx.allocate_on_comm_stream,
        )
        if ctx.async_finish:
            event.current_stream_wait()
        return None, grad_x, None, None, None, None, None


if HAVE_DEEP_EP_V2:

    def deepepv2_dispatch(
        buffer,
        x,
        token_indices,
        token_probs,
        num_experts,
        num_max_tokens_per_rank,
        expert_alignment=1,
        num_sms=0,
        async_finish=False,
        allocate_on_comm_stream=False,
    ):
        """Perform DeepEP v2 elastic dispatch."""
        return DeepepV2Dispatch.apply(
            buffer,
            x.contiguous(),
            token_indices,
            token_probs,
            num_experts,
            num_max_tokens_per_rank,
            expert_alignment,
            num_sms,
            async_finish,
            allocate_on_comm_stream,
        )

    def deepepv2_combine(
        buffer, x, handle, num_sms=0, async_finish=False, allocate_on_comm_stream=False
    ):
        """Perform DeepEP v2 elastic combine."""
        return DeepepV2Combine.apply(
            buffer, x.contiguous(), handle, num_sms, async_finish, allocate_on_comm_stream
        )

else:
    deepepv2_dispatch = None
    deepepv2_combine = None

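# Illustrative end-to-end sketch, not part of the diff. `ep_group`, the routing
# tensors, and `run_experts` are assumed names; the real integration point is the
# flex token dispatcher gated in fine_grained_callables.py above.
#
#     buffer = get_elastic_buffer(
#         ep_group, num_max_tokens_per_rank=4096,
#         hidden=hidden_states.size(-1), num_topk=8,
#     )
#     recv_x, recv_idx, recv_probs, tokens_per_expert, handle = deepepv2_dispatch(
#         buffer, hidden_states, token_indices, token_probs,
#         num_experts=64, num_max_tokens_per_rank=4096,
#     )
#     expert_output = run_experts(recv_x, tokens_per_expert)
#     combined_x, _ = deepepv2_combine(buffer, expert_output, handle)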

try:
    from deep_ep import HybridEPBuffer
