Skip to content

Commit ac368f4

Browse files
author
niushengxiao
committed
fix: fix bugs
1 parent 41f3947 commit ac368f4

3 files changed

Lines changed: 17 additions & 20 deletions

File tree

lightllm/common/kv_cache_mem_manager/mem_manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ def profile_size(self, mem_fraction):
6767

6868
torch.cuda.empty_cache()
6969
world_size = dist.get_world_size()
70-
71-
available_memory = get_available_gpu_memory(world_size) * mem_fraction
70+
available_memory = get_available_gpu_memory(world_size) - get_total_gpu_memory() * (1 - mem_fraction)
7271
cell_size = self.get_cell_size()
7372
self.size = int(available_memory * 1024 ** 3 / cell_size)
7473
if world_size > 1:

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,3 @@ nixl==1.1.0
9898
xformers==0.0.35
9999
redis==7.3.0
100100
litellm>=1.52.0,<1.85
101-
flash-attn-4[13]==4.0.0b14

test/benchmark/service/benchmark_multiturn.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -317,20 +317,8 @@ def stream_one_turn(
317317
continue
318318

319319
if first_token_time is not None:
320-
generated_text = "".join(generated_text_parts)
321-
estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
322-
estimated_completion_tokens = max(estimated_completion_tokens, len(generated_text_parts))
323-
print(f"\n[turn warning] {e}; keeping partial turn with estimated usage (attempt={attempt + 1})")
324-
return {
325-
"ttft": first_token_time - start_time,
326-
"decode_times": decode_times,
327-
"prompt_tokens": prompt_tokens or prompt_token_len,
328-
"completion_tokens": completion_tokens or estimated_completion_tokens,
329-
"cached_tokens": cached_tokens,
330-
"cached_tokens_reported": cached_tokens_reported,
331-
"usage_estimated": completion_tokens == 0 or prompt_tokens == 0,
332-
"generated_text": generated_text,
333-
}
320+
print(f"\n[turn warning] {e}; discarding partial turn (attempt={attempt + 1})")
321+
return None
334322

335323
print(f"\n[turn exception] {e}")
336324
return None
@@ -344,15 +332,25 @@ def stream_one_turn(
344332
continue
345333
return None
346334

335+
generated_text = "".join(generated_text_parts)
336+
usage_estimated = False
337+
if prompt_tokens == 0:
338+
prompt_tokens = prompt_token_len
339+
usage_estimated = True
340+
if completion_tokens == 0:
341+
estimated_completion_tokens = len(tokenizer.encode(generated_text, add_special_tokens=False))
342+
completion_tokens = max(estimated_completion_tokens, len(generated_text_parts))
343+
usage_estimated = True
344+
347345
return {
348346
"ttft": first_token_time - start_time,
349347
"decode_times": decode_times,
350348
"prompt_tokens": prompt_tokens,
351349
"completion_tokens": completion_tokens,
352350
"cached_tokens": cached_tokens,
353351
"cached_tokens_reported": cached_tokens_reported,
354-
"usage_estimated": False,
355-
"generated_text": "".join(generated_text_parts),
352+
"usage_estimated": usage_estimated,
353+
"generated_text": generated_text,
356354
}
357355

358356
return None
@@ -402,8 +400,9 @@ def run_session(
402400
print(
403401
f"\rconc={progress_state['concurrency']} "
404402
f"finished_turns={progress_state['finished_turns']} "
405-
f"active_sessions={progress_state['active_sessions']}",
403+
f"active_sessions={progress_state['active_sessions']}\033[K",
406404
end="",
405+
flush=True,
407406
)
408407
turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment)
409408
prompt, prompt_len = append_turn_input(

0 commit comments

Comments
 (0)