Skip to content

Commit 814d267

Browse files
author
niushengxiao
committed
fix: fix benchmark_multiturn.py
1 parent 4aee0ba commit 814d267

3 files changed

Lines changed: 30 additions & 7 deletions

File tree

lightllm/common/linear_att_cache_manager/linear_att_buffer_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ def __init__(
2626
dtype=self.linear_config.conv_state_dtype,
2727
shape=self.linear_config.get_conv_state_shape(),
2828
layer_num=self.linear_config.linear_layer_num,
29-
device="cuda",
29+
device="cpu",
3030
size_first=True,
3131
)
3232
self.ssm_state_cache = LayerCache(
3333
size=self.size,
3434
dtype=self.linear_config.ssm_state_dtype,
3535
shape=self.linear_config.get_ssm_state_shape(),
3636
layer_num=self.linear_config.linear_layer_num,
37-
device="cuda",
37+
device="cpu",
3838
size_first=True,
3939
)
4040
self.clear_to_init_state()

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ nixl==1.1.0
9898
xformers==0.0.35
9999
redis==7.3.0
100100
litellm>=1.52.0,<1.85
101-
flash-attn-4[13]==4.0.0b13
101+
flash-attn-4[13]==4.0.0b14

test/benchmark/service/benchmark_multiturn.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
_STREAM_READ_BUFSIZE = 1 << 20
5151
_STREAM_MAX_LINE_SIZE = 1 << 20
5252
_DEFAULT_TRANSIENT_RETRIES = 2
53+
_PROMPT_LEN_OVERLAP_CHARS = 512
5354
_TRANSIENT_STREAM_ERRORS = (
5455
aiohttp.ServerDisconnectedError,
5556
aiohttp.ClientPayloadError,
@@ -177,6 +178,7 @@ def gen_session_initial_prompt(
177178
def append_turn_input(
178179
tokenizer,
179180
prompt: str,
181+
prompt_token_len: int,
180182
generated_text: str,
181183
turn_input_increment: int,
182184
rng: random.Random,
@@ -188,8 +190,22 @@ def append_turn_input(
188190
new_text = decode_ids(tokenizer, new_ids)
189191
else:
190192
new_text = ""
191-
new_prompt = prompt + generated_text + new_text
192-
new_len = len(tokenizer.encode(new_prompt, add_special_tokens=False))
193+
194+
appended_text = generated_text + new_text
195+
new_prompt = prompt + appended_text
196+
if not appended_text:
197+
return new_prompt, prompt_token_len
198+
199+
# Token merges only depend on a small boundary window, so avoid
200+
# re-encoding the entire prompt on every turn.
201+
overlap_text = prompt[-_PROMPT_LEN_OVERLAP_CHARS:]
202+
if overlap_text:
203+
overlap_token_len = len(tokenizer.encode(overlap_text, add_special_tokens=False))
204+
merged_token_len = len(tokenizer.encode(overlap_text + appended_text, add_special_tokens=False))
205+
appended_token_len = max(merged_token_len - overlap_token_len, 0)
206+
else:
207+
appended_token_len = len(tokenizer.encode(appended_text, add_special_tokens=False))
208+
new_len = prompt_token_len + appended_token_len
193209
return new_prompt, new_len
194210

195211

@@ -352,7 +368,12 @@ async def run_session(
352368
"""Run a single multi-turn dialogue session. Returns a list of per-turn
353369
stat dicts (same schema as stream_one_turn output)."""
354370
rng = random.Random(base_seed + session_id)
355-
prompt, prompt_len = gen_session_initial_prompt(tokenizer, start_input_len, base_seed + session_id)
371+
prompt, prompt_len = await asyncio.to_thread(
372+
gen_session_initial_prompt,
373+
tokenizer,
374+
start_input_len,
375+
base_seed + session_id,
376+
)
356377

357378
per_turn: List[Dict] = []
358379
turn_idx = 0
@@ -370,9 +391,11 @@ async def run_session(
370391
end="",
371392
)
372393
turn_input_len = rng.randint(min_turn_input_increment, turn_input_increment)
373-
prompt, prompt_len = append_turn_input(
394+
prompt, prompt_len = await asyncio.to_thread(
395+
append_turn_input,
374396
tokenizer,
375397
prompt,
398+
result["prompt_tokens"] or prompt_len,
376399
result["generated_text"],
377400
turn_input_len,
378401
rng,

0 commit comments

Comments
 (0)