Skip to content

Commit 4aee0ba

Browse files
author
niushengxiao
committed
fix: fix ep v2 for sm100
1 parent feac819 commit 4aee0ba

12 files changed

Lines changed: 123 additions & 26 deletions

File tree

lightllm/common/basemodel/attention/fa4/fp.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ class Fa4PrefillAttState(PagedFa3PrefillAttState):
5959
def _normal_prefill_att(
6060
self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, att_control: AttControl, alloc_func=torch.empty
6161
) -> torch.Tensor:
62+
import triton
63+
6264
if att_control.use_sliding_window:
6365
window_size = att_control.sliding_window
6466
else:
@@ -81,7 +83,7 @@ def _normal_prefill_att(
8183
cu_seqlens_q=self.cu_seqlens_q,
8284
seqused_k=self.infer_state.b_seq_len.int(),
8385
max_seqlen_q=self.infer_state.max_q_seq_len,
84-
max_seqlen_k=self.infer_state.max_kv_seq_len,
86+
max_seqlen_k=triton.cdiv(self.infer_state.max_kv_seq_len, self.backend.page_size) * self.backend.page_size,
8587
page_table=self.page_table,
8688
softmax_scale=softmax_scale,
8789
causal=True,

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def experts(
154154
topk_group: int,
155155
num_expert_group: int,
156156
is_prefill: Optional[bool] = None,
157+
is_cuda_graph: bool = False,
157158
) -> torch.Tensor:
158159
"""Backward compatible method that routes to platform-specific implementation."""
159160
return self.fuse_moe_impl(
@@ -169,6 +170,7 @@ def experts(
169170
topk_group=topk_group,
170171
num_expert_group=num_expert_group,
171172
is_prefill=is_prefill,
173+
is_cuda_graph=is_cuda_graph,
172174
)
173175

174176
def use_sm100_mega_moe(self) -> bool:

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,54 @@ def _mega_moe(
9999
)
100100
return output
101101

102+
def _sm100_fp4_cuda_graph_moe(
103+
self,
104+
hidden_states: torch.Tensor,
105+
w13: WeightPack,
106+
w2: WeightPack,
107+
topk_weights: torch.Tensor,
108+
topk_ids: torch.Tensor,
109+
) -> torch.Tensor:
110+
from deep_gemm.utils import per_token_cast_to_fp8
111+
112+
buffer = getattr(dist_group_manager, "ep_buffer", None)
113+
if buffer is None:
114+
raise RuntimeError("SM100 CUDA graph MoE fallback requires dist_group_manager.ep_buffer")
115+
116+
num_max_tokens_per_rank = get_deepep_num_max_dispatch_tokens_per_rank_decode()
117+
qinput_tensor = per_token_cast_to_fp8(
118+
hidden_states,
119+
use_ue8m0=True,
120+
gran_k=self.quant_method.block_size,
121+
use_packed_ue8m0=True,
122+
)
123+
alignment = getattr(dist_group_manager, "ep_expert_alignment", 128)
124+
cumulative_stats = self._get_mega_moe_stats(w13.weight.shape[0], hidden_states.device)
125+
recv_x, recv_topk_idx, recv_topk_weights, handle, _ = buffer.dispatch(
126+
qinput_tensor,
127+
topk_idx=topk_ids,
128+
topk_weights=topk_weights,
129+
cumulative_local_expert_recv_stats=cumulative_stats,
130+
num_experts=self.total_expert_num_contain_redundancy,
131+
num_max_tokens_per_rank=num_max_tokens_per_rank,
132+
expert_alignment=alignment,
133+
do_cpu_sync=False,
134+
do_handle_copy=False,
135+
do_expand=True,
136+
use_tma_aligned_col_major_sf=True,
137+
)
138+
gemm_out = self.prefilled_group_gemm(
139+
handle.psum_num_recv_tokens_per_expert,
140+
recv_x,
141+
recv_topk_idx,
142+
recv_topk_weights,
143+
w13,
144+
w2,
145+
hidden_states.dtype,
146+
)
147+
combined_x, _, _ = buffer.combine(gemm_out, handle=handle, topk_weights=None)
148+
return combined_x
149+
102150
def _select_experts(
103151
self,
104152
input_tensor: torch.Tensor,
@@ -147,11 +195,17 @@ def _fused_experts(
147195
topk_ids: torch.Tensor,
148196
router_logits: Optional[torch.Tensor] = None,
149197
is_prefill: Optional[bool] = None,
198+
is_cuda_graph: bool = False,
150199
):
151200

152201
w13_weight, w13_scale = w13.weight, w13.weight_scale
153202
w2_weight, w2_scale = w2.weight, w2.weight_scale
154203
if self._use_sm100_fp4_moe():
204+
# DeepGEMM's official Mega MoE example is an eager fused path. For
205+
# decode CUDA graph, use the official ElasticBuffer + grouped GEMM
206+
# baseline instead of capturing Mega MoE's NVLink barrier kernel.
207+
if is_cuda_graph and not is_prefill:
208+
return self._sm100_fp4_cuda_graph_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long))
155209
return self._mega_moe(input_tensor, w13, w2, topk_weights, topk_ids.to(torch.long))
156210

157211
use_fp8_w8a8 = self.quant_method.method_name != "none"

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def _fused_experts(
2929
topk_ids: torch.Tensor,
3030
router_logits: Optional[torch.Tensor] = None,
3131
is_prefill: Optional[bool] = None,
32+
is_cuda_graph: bool = False,
3233
):
3334

3435
w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def _fused_experts(
9191
topk_ids: torch.Tensor,
9292
router_logits: Optional[torch.Tensor] = None,
9393
is_prefill: bool = False,
94+
is_cuda_graph: bool = False,
9495
):
9596
w13_weight, w13_scale = w13.weight, w13.weight_scale
9697
w2_weight, w2_scale = w2.weight, w2.weight_scale
@@ -125,6 +126,7 @@ def __call__(
125126
topk_group: int,
126127
num_expert_group: int,
127128
is_prefill: Optional[bool] = None,
129+
is_cuda_graph: bool = False,
128130
):
129131
topk_weights, topk_ids = self._select_experts(
130132
input_tensor=input_tensor,
@@ -145,5 +147,6 @@ def __call__(
145147
topk_ids=topk_ids,
146148
router_logits=router_logits,
147149
is_prefill=is_prefill,
150+
is_cuda_graph=is_cuda_graph,
148151
)
149152
return output

lightllm/common/quantization/deepgemm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod):
131131
def __init__(self):
132132
super().__init__()
133133
self.block_size = 32
134+
self.ue8m0_pack_factor = 4
134135
self.weight_suffix = "weight"
135136
self.weight_zero_point_suffix = None
136137
self.weight_scale_suffix = None
@@ -179,14 +180,26 @@ def apply(
179180
def _create_weight(
180181
self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
181182
) -> Tuple[WeightPack, List[WeightPack]]:
183+
import deep_gemm
184+
182185
out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims
183186
assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension"
184187
assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size"
188+
assert (
189+
in_dim % (self.block_size * self.ue8m0_pack_factor) == 0
190+
), "SM100 FP4 scale layout requires input dimension divisible by 128"
185191
expert_prefix = (num_experts,) if num_experts > 1 else ()
186192
weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id)
187-
weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda(
193+
raw_weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float32).cuda(
188194
device_id
189195
)
196+
weight_scale = deep_gemm.transform_sf_into_required_layout(
197+
raw_weight_scale,
198+
out_dim,
199+
in_dim,
200+
(1, self.block_size),
201+
num_experts if num_experts > 1 else None,
202+
)
190203
mm_param = WeightPack(weight=weight, weight_scale=weight_scale)
191204
mm_param_list = self._split_weight_pack(
192205
mm_param,

lightllm/distributed/communication_op.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self):
111111
self.ep_low_latency_buffer = None
112112
self.ep_mega_moe_buffer = None
113113
self.ep_num_sms = None
114+
self.ep_expert_alignment = 128
114115

115116
def __len__(self):
116117
return len(self.groups)
@@ -156,33 +157,26 @@ def new_deepep_group(
156157
self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank
157158
self.ll_hidden = hidden_size
158159
self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size
159-
self.ep_buffer = deep_ep.ElasticBuffer(
160-
deepep_group,
161-
num_max_tokens_per_rank=self.ll_num_tokens,
162-
hidden=self.ll_hidden,
163-
num_topk=num_experts_per_tok,
164-
use_fp8_dispatch=True,
165-
allow_multiple_reduction=False,
166-
)
167-
self.ep_mega_moe_buffer = None
168160
self.ep_low_latency_buffer = None
169-
if not is_sm100_gpu():
170-
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
171-
self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts
172-
)
173-
self.ep_low_latency_buffer = deep_ep.Buffer(
174-
deepep_group,
175-
int(1e9),
176-
num_rdma_bytes,
177-
low_latency_mode=True,
178-
num_qps_per_rank=(self.ll_num_experts // global_world_size),
179-
)
180-
else:
161+
self.ep_mega_moe_buffer = None
162+
if is_sm100_gpu():
181163
if moe_intermediate_size is None:
182164
raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config")
183165

184166
import deep_gemm
185167

168+
self.ep_expert_alignment = deep_gemm.get_theoretical_mk_alignment_for_contiguous_layout()
169+
deep_gemm.set_mk_alignment_for_contiguous_layout(self.ep_expert_alignment)
170+
# Mega MoE is the eager fast path, while ElasticBuffer provides the official
171+
# CUDA-graph-compatible baseline for decode.
172+
self.ep_buffer = deep_ep.ElasticBuffer(
173+
deepep_group,
174+
num_max_tokens_per_rank=self.ll_decode_num_tokens,
175+
hidden=self.ll_hidden,
176+
num_topk=num_experts_per_tok,
177+
use_fp8_dispatch=True,
178+
allow_multiple_reduction=False,
179+
)
186180
self.ep_mega_moe_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
187181
deepep_group,
188182
self.ll_num_experts,
@@ -191,6 +185,28 @@ def new_deepep_group(
191185
self.ll_hidden,
192186
moe_intermediate_size,
193187
)
188+
self._set_num_sms_for_deep_gemm(0)
189+
logger.info("SM100 detected: use Mega MoE for eager path and ElasticBuffer for CUDA graph decode.")
190+
return
191+
192+
self.ep_buffer = deep_ep.ElasticBuffer(
193+
deepep_group,
194+
num_max_tokens_per_rank=self.ll_num_tokens,
195+
hidden=self.ll_hidden,
196+
num_topk=num_experts_per_tok,
197+
use_fp8_dispatch=True,
198+
allow_multiple_reduction=False,
199+
)
200+
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
201+
self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts
202+
)
203+
self.ep_low_latency_buffer = deep_ep.Buffer(
204+
deepep_group,
205+
int(1e9),
206+
num_rdma_bytes,
207+
low_latency_mode=True,
208+
num_qps_per_rank=(self.ll_num_experts // global_world_size),
209+
)
194210
theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok)
195211
self._set_num_sms_for_deep_gemm(theoretical_sms)
196212

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def _moe_ffn_edp(
258258
topk_group=self.topk_group,
259259
num_expert_group=self.n_group,
260260
is_prefill=infer_state.is_prefill,
261+
is_cuda_graph=infer_state.is_cuda_graph,
261262
)
262263

263264
if self.n_shared_experts is not None:

lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def _moe_ffn_edp(
104104
topk_group=None,
105105
num_expert_group=None,
106106
is_prefill=infer_state.is_prefill,
107+
is_cuda_graph=infer_state.is_cuda_graph,
107108
)
108109

109110
ep_output = ep_output.view(token_num, hidden_dim)

lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def _moe_ffn_edp(
156156
topk_group=None,
157157
num_expert_group=None,
158158
is_prefill=infer_state.is_prefill,
159+
is_cuda_graph=infer_state.is_cuda_graph,
159160
)
160161
ep_output = ep_output.view(token_num, hidden_dim)
161162
ep_output.add_(shared_expert_out)

0 commit comments

Comments
 (0)