Skip to content

Commit f91690d

Browse files
author
niushengxiao
committed
fix: fix bugs
1 parent 41f3947 commit f91690d

6 files changed

Lines changed: 22 additions & 25 deletions

File tree

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _maybe_upgrade_quant_method_for_ep_moe(self, quant_method: QuantizationMetho
7676
if not self.enable_ep_moe:
7777
return quant_method
7878

79-
target_method = "deepgemm-fp8fp4-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128"
79+
target_method = "deepgemm-fp4fp8-b32" if is_sm100_gpu() else "deepgemm-fp8w8a8-b128"
8080
if quant_method.method_name == "none":
8181
from lightllm.common.quantization.registry import QUANTMETHODS
8282

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def _get_ep_num_sms(self) -> int:
2828
return getattr(dist_group_manager, "ep_num_sms", None) or 0
2929

3030
def _use_sm100_fp4_moe(self) -> bool:
31-
return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp8fp4-b32"
31+
return is_sm100_gpu() and self.quant_method.method_name == "deepgemm-fp4fp8-b32"
3232

3333
def _get_mega_moe_weights(self, w13: WeightPack, w2: WeightPack):
3434
cache_key = (

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ def profile_size(self, mem_fraction):
6767

6868
torch.cuda.empty_cache()
6969
world_size = dist.get_world_size()
70-
71-
available_memory = get_available_gpu_memory(world_size) * mem_fraction
70+
available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction)
7271
cell_size = self.get_cell_size()
7372
self.size = int(available_memory * 1024 ** 3 / cell_size)
7473
if world_size > 1:

lightllm/common/quantization/deepgemm.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def _create_weight(
126126
return mm_param, mm_param_list
127127

128128

129-
@QUANTMETHODS.register(["deepgemm-fp8fp4-b32"], platform="cuda")
129+
@QUANTMETHODS.register(["deepgemm-fp4fp8-b32"], platform="cuda")
130130
class DeepGEMMFP8FP4B32QuantizationMethod(DeepGEMMBaseQuantizationMethod):
131131
def __init__(self):
132132
super().__init__()
@@ -139,7 +139,7 @@ def __init__(self):
139139

140140
@property
141141
def method_name(self):
142-
return "deepgemm-fp8fp4-b32"
142+
return "deepgemm-fp4fp8-b32"
143143

144144
def quantize(self, weight: torch.Tensor, output: WeightPack):
145145
from deep_gemm.utils import per_token_cast_to_fp4
@@ -174,7 +174,7 @@ def apply(
174174
use_custom_tensor_mananger: bool = True,
175175
bias: Optional[torch.Tensor] = None,
176176
) -> torch.Tensor:
177-
raise NotImplementedError("deepgemm-fp8fp4-b32 is only implemented for fused MoE expert weights")
177+
raise NotImplementedError("deepgemm-fp4fp8-b32 is only implemented for fused MoE expert weights")
178178

179179
def _create_weight(
180180
self, out_dims: Union[int, List[int]], in_dim: int, dtype: torch.dtype, device_id: int, num_experts: int = 1

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,3 @@ nixl==1.1.0
9898
xformers==0.0.35
9999
redis==7.3.0
100100
litellm>=1.52.0,<1.85
101-
flash-attn-4[13]==4.0.0b14

test/benchmark/service/benchmark_multiturn.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,8 @@ def stream_one_turn(
317317
continue
318318

319319
if first_token_time is not None:
320-
generated_text = "".join(generated_text_parts)
321-
estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
322-
estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts))
323-
print(f"\n[turn warning] {e}; keeping partial turn with estimated usage (attempt={attempt + 1})")
324-
return {
325-
"ttft": first_token_time - start_time,
326-
"decode_times": decode_times,
327-
"prompt_tokens": prompt_tokens or prompt_token_len,
328-
"completion_tokens": completion_tokens or estimated_completion_tokens,
329-
"cached_tokens": cached_tokens,
330-
"cached_tokens_reported": cached_tokens_reported,
331-
"usage_estimated": completion_tokens == 0 or prompt_tokens == 0,
332-
"generated_text": generated_text,
333-
}
320+
print(f"\n[turn warning] {e}; discarding partial turn (attempt={attempt + 1})")
321+
return None
334322

335323
print(f"\n[turn exception] {e}")
336324
return None
@@ -344,15 +332,25 @@ def stream_one_turn(
344332
continue
345333
return None
346334

335+
generated_text = "".join(generated_text_parts)
336+
usage_estimated = False
337+
if prompt_tokens == 0:
338+
prompt_tokens = prompt_token_len
339+
usage_estimated = True
340+
if completion_tokens == 0:
341+
estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
342+
completion_tokens = max(estimated_completion_tokens, len(generated_text_parts))
343+
usage_estimated = True
344+
347345
return {
348346
"ttft": first_token_time - start_time,
349347
"decode_times": decode_times,
350348
"prompt_tokens": prompt_tokens,
351349
"completion_tokens": completion_tokens,
352350
"cached_tokens": cached_tokens,
353351
"cached_tokens_reported": cached_tokens_reported,
354-
"usage_estimated": False,
355-
"generated_text": "".join(generated_text_parts),
352+
"usage_estimated": usage_estimated,
353+
"generated_text": generated_text,
356354
}
357355

358356
return None
@@ -402,8 +400,9 @@ def run_session(
402400
print(
403401
f"\rconc={progress_state['concurrency']} "
404402
f"finished_turns={progress_state['finished_turns']} "
405-
f"active_sessions={progress_state['active_sessions']}",
403+
f"active_sessions={progress_state['active_sessions']}\033[K",
406404
end="",
405+
flush=True,
407406
)
408407
turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment)
409408
prompt, prompt_len = append_turn_input(

0 commit comments

Comments
 (0)