5050_STREAM_READ_BUFSIZE = 1 << 20
5151_STREAM_MAX_LINE_SIZE = 1 << 20
5252_DEFAULT_TRANSIENT_RETRIES = 2
53+ _PROMPT_LEN_OVERLAP_CHARS = 512
5354_TRANSIENT_STREAM_ERRORS = (
5455 aiohttp .ServerDisconnectedError ,
5556 aiohttp .ClientPayloadError ,
@@ -177,6 +178,7 @@ def gen_session_initial_prompt(
177178def append_turn_input (
178179 tokenizer ,
179180 prompt : str ,
181+ prompt_token_len : int ,
180182 generated_text : str ,
181183 turn_input_increment : int ,
182184 rng : random .Random ,
@@ -188,8 +190,22 @@ def append_turn_input(
188190 new_text = decode_ids (tokenizer , new_ids )
189191 else :
190192 new_text = ""
191- new_prompt = prompt + generated_text + new_text
192- new_len = len (tokenizer .encode (new_prompt , add_special_tokens = False ))
193+
194+ appended_text = generated_text + new_text
195+ new_prompt = prompt + appended_text
196+ if not appended_text :
197+ return new_prompt , prompt_token_len
198+
199+ # Token merges only depend on a small boundary window, so avoid
200+ # re-encoding the entire prompt on every turn.
201+ overlap_text = prompt [- _PROMPT_LEN_OVERLAP_CHARS :]
202+ if overlap_text :
203+ overlap_token_len = len (tokenizer .encode (overlap_text , add_special_tokens = False ))
204+ merged_token_len = len (tokenizer .encode (overlap_text + appended_text , add_special_tokens = False ))
205+ appended_token_len = max (merged_token_len - overlap_token_len , 0 )
206+ else :
207+ appended_token_len = len (tokenizer .encode (appended_text , add_special_tokens = False ))
208+ new_len = prompt_token_len + appended_token_len
193209 return new_prompt , new_len
194210
195211
@@ -352,7 +368,12 @@ async def run_session(
352368 """Run a single multi-turn dialogue session. Returns a list of per-turn
353369 stat dicts (same schema as stream_one_turn output)."""
354370 rng = random .Random (base_seed + session_id )
355- prompt , prompt_len = gen_session_initial_prompt (tokenizer , start_input_len , base_seed + session_id )
371+ prompt , prompt_len = await asyncio .to_thread (
372+ gen_session_initial_prompt ,
373+ tokenizer ,
374+ start_input_len ,
375+ base_seed + session_id ,
376+ )
356377
357378 per_turn : List [Dict ] = []
358379 turn_idx = 0
@@ -370,9 +391,11 @@ async def run_session(
370391 end = "" ,
371392 )
372393 turn_input_len = rng .randint (min_turn_input_increment , turn_input_increment )
373- prompt , prompt_len = append_turn_input (
394+ prompt , prompt_len = await asyncio .to_thread (
395+ append_turn_input ,
374396 tokenizer ,
375397 prompt ,
398+ result ["prompt_tokens" ] or prompt_len ,
376399 result ["generated_text" ],
377400 turn_input_len ,
378401 rng ,
0 commit comments