Skip to content

Commit 82f7675

Browse files
author
niushengxiao
committed
fix: fix ep v2 for sm100
1 parent feac819 commit 82f7675

3 files changed

Lines changed: 40 additions & 22 deletions

File tree

lightllm/common/quantization/deepgemm.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod):
131131
def __init__(self):
132132
super().__init__()
133133
self.block_size = 32
134+
self.ue8m0_pack_factor = 4
134135
self.weight_suffix = "weight"
135136
self.weight_zero_point_suffix = None
136137
self.weight_scale_suffix = None
@@ -179,14 +180,26 @@ def apply(
179180
def _create_weight(
180181
self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1
181182
) -> Tuple[WeightPack, List[WeightPack]]:
183+
import deep_gemm
184+
182185
out_dim = sum(out_dims) if isinstance(out_dims, list) else out_dims
183186
assert in_dim % 2 == 0, "FP4 packed weight requires even input dimension"
184187
assert in_dim % self.block_size == 0, "FP4 scale dimension must be divisible by block_size"
188+
assert (
189+
in_dim % (self.block_size * self.ue8m0_pack_factor) == 0
190+
), "SM100 FP4 scale layout requires input dimension divisible by 128"
185191
expert_prefix = (num_experts,) if num_experts > 1 else ()
186192
weight = torch.empty(expert_prefix + (out_dim, in_dim // 2), dtype=torch.int8).cuda(device_id)
187-
weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.int32).cuda(
193+
raw_weight_scale = torch.empty(expert_prefix + (out_dim, in_dim // self.block_size), dtype=torch.float32).cuda(
188194
device_id
189195
)
196+
weight_scale = deep_gemm.transform_sf_into_required_layout(
197+
raw_weight_scale,
198+
out_dim,
199+
in_dim,
200+
(1, self.block_size),
201+
num_experts if num_experts > 1 else None,
202+
)
190203
mm_param = WeightPack(weight=weight, weight_scale=weight_scale)
191204
mm_param_list = self._split_weight_pack(
192205
mm_param,

lightllm/distributed/communication_op.py

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -156,28 +156,10 @@ def new_deepep_group(
156156
self.ll_decode_num_tokens = decode_num_max_dispatch_tokens_per_rank
157157
self.ll_hidden = hidden_size
158158
self.ll_num_experts = n_routed_experts + get_redundancy_expert_num() * global_world_size
159-
self.ep_buffer = deep_ep.ElasticBuffer(
160-
deepep_group,
161-
num_max_tokens_per_rank=self.ll_num_tokens,
162-
hidden=self.ll_hidden,
163-
num_topk=num_experts_per_tok,
164-
use_fp8_dispatch=True,
165-
allow_multiple_reduction=False,
166-
)
167-
self.ep_mega_moe_buffer = None
168159
self.ep_low_latency_buffer = None
169-
if not is_sm100_gpu():
170-
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
171-
self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts
172-
)
173-
self.ep_low_latency_buffer = deep_ep.Buffer(
174-
deepep_group,
175-
int(1e9),
176-
num_rdma_bytes,
177-
low_latency_mode=True,
178-
num_qps_per_rank=(self.ll_num_experts // global_world_size),
179-
)
180-
else:
160+
self.ep_mega_moe_buffer = None
161+
if is_sm100_gpu():
162+
self.ep_buffer = None
181163
if moe_intermediate_size is None:
182164
raise ValueError("SM100 Mega MoE requires moe_intermediate_size or intermediate_size in model config")
183165

@@ -191,6 +173,28 @@ def new_deepep_group(
191173
self.ll_hidden,
192174
moe_intermediate_size,
193175
)
176+
self._set_num_sms_for_deep_gemm(0)
177+
logger.info("SM100 detected: skip DeepEP ElasticBuffer init and use Mega MoE buffer only.")
178+
return
179+
180+
self.ep_buffer = deep_ep.ElasticBuffer(
181+
deepep_group,
182+
num_max_tokens_per_rank=self.ll_num_tokens,
183+
hidden=self.ll_hidden,
184+
num_topk=num_experts_per_tok,
185+
use_fp8_dispatch=True,
186+
allow_multiple_reduction=False,
187+
)
188+
num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint(
189+
self.ll_decode_num_tokens, self.ll_hidden, global_world_size, self.ll_num_experts
190+
)
191+
self.ep_low_latency_buffer = deep_ep.Buffer(
192+
deepep_group,
193+
int(1e9),
194+
num_rdma_bytes,
195+
low_latency_mode=True,
196+
num_qps_per_rank=(self.ll_num_experts // global_world_size),
197+
)
194198
theoretical_sms = self.ep_buffer.get_theoretical_num_sms(self.ll_num_experts, num_experts_per_tok)
195199
self._set_num_sms_for_deep_gemm(theoretical_sms)
196200

lightllm/server/api_start.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
auto_set_max_req_total_len,
2626
)
2727
from lightllm.utils.dist_check_utils import auto_configure_allreduce_flags_from_args
28+
from lightllm.utils.device_utils import is_sm100_gpu
2829

2930
logger = init_logger(__name__)
3031

0 commit comments

Comments
 (0)