Skip to content

Commit 80a88bf

Browse files
tianmu-li and claude committed
feat: pre-compute ISL token counts for multi-turn dataset-history mode
- Add _precompute_isl_for_multi_turn() in execute.py: runs apply_chat_template(messages, tokenize=True, add_generation_prompt=True) once per client turn at setup time and stores results in sample["input_tokens"], hitting the IslTrigger sync fast path (len(token_ids)) with zero hot-path cost.
- Add _extract_prompt_text() in session.py: refactors inline message content extraction to handle list-form multimodal content safely, fixing a crash when content is a list (e.g. vision/tool-call messages).
- Add unit tests for both helpers and two integration tests covering target_concurrency cap enforcement and pipeline exception propagation.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 8025c45 commit 80a88bf

5 files changed

Lines changed: 328 additions & 6 deletions

File tree

src/inference_endpoint/commands/benchmark/execute.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,42 @@ def _load_datasets(
288288
return dataloader, accuracy_datasets, eval_configs
289289

290290

291+
def _precompute_isl_for_multi_turn(
    dataloader: MultiTurnDataset, tokenizer_name: str
) -> None:
    """Pre-tokenize each sample's chat messages and cache the token ids.

    Applies the model chat template once per client turn at setup time so
    the IslTrigger hot path can take the synchronous len(token_ids) route
    instead of tokenizing prompt text on the fly. Dataset-history turns
    benefit directly; live-history turns replace 'messages' at runtime, so
    their cached input_tokens are a stale-but-acceptable approximation.
    """
    # Deferred import: transformers is an optional dependency and importing
    # it at module scope risks an import cycle (same pattern as
    # _annotate_response_token_counts in this file).
    from transformers import AutoTokenizer  # noqa: PLC0415

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    failure_count = 0
    for row in dataloader.data or []:
        turn_messages = row.get("messages")
        if not turn_messages:
            # Nothing to tokenize for this sample (e.g. plain-prompt rows).
            continue
        try:
            ids: list[int] = tokenizer.apply_chat_template(
                turn_messages, tokenize=True, add_generation_prompt=True
            )
        except Exception:  # template errors vary by model; skip gracefully
            failure_count += 1
        else:
            row["input_tokens"] = ids
    if failure_count:
        logger.warning(
            "ISL pre-computation: %d turn(s) skipped (apply_chat_template failed)",
            failure_count,
        )
325+
326+
291327
def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkContext:
292328
"""Load tokenizer, dataset, create scheduler, setup report dir."""
293329
# CPU affinity
@@ -317,6 +353,10 @@ def setup_benchmark(config: BenchmarkConfig, test_mode: TestMode) -> BenchmarkCo
317353
# Datasets
318354
dataloader, accuracy_datasets, eval_configs = _load_datasets(config, report_dir)
319355

356+
if isinstance(dataloader, MultiTurnDataset) and tokenizer_name is not None:
357+
logger.info("Pre-computing ISL token counts for multi-turn dataset…")
358+
_precompute_isl_for_multi_turn(dataloader, tokenizer_name)
359+
320360
# Setup runtime settings using factory method
321361
rt_settings = RuntimeSettings.from_config(config, dataloader.num_samples())
322362

src/inference_endpoint/load_generator/session.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,26 @@
4747
_WARMUP_ENABLED = os.environ.get("ENABLE_WARMUP") == "1"
4848

4949

50+
def _extract_prompt_text(messages: list[Any]) -> str | None:
51+
"""Join text content from an OpenAI messages list; handles list-form multimodal content."""
52+
parts: list[str] = []
53+
for m in messages:
54+
if not isinstance(m, dict):
55+
continue
56+
c = m.get("content")
57+
if isinstance(c, str) and c:
58+
parts.append(c)
59+
elif isinstance(c, list):
60+
parts.extend(
61+
p["text"]
62+
for p in c
63+
if isinstance(p, dict)
64+
and p.get("type") == "text"
65+
and isinstance(p.get("text"), str)
66+
)
67+
return "\n".join(parts) if parts else None
68+
69+
5070
# ---------------------------------------------------------------------------
5171
# Phase configuration
5272
# ---------------------------------------------------------------------------
@@ -204,12 +224,7 @@ def issue(
204224
token_ids = data.get("input_tokens") or data.get("token_ids")
205225
prompt_text = data.get("prompt")
206226
if prompt_text is None and "messages" in data:
207-
parts: list[str] = [
208-
m["content"]
209-
for m in data["messages"]
210-
if isinstance(m, dict) and m.get("content")
211-
]
212-
prompt_text = "\n".join(parts) if parts else None
227+
prompt_text = _extract_prompt_text(data["messages"])
213228
prompt_data = PromptData(
214229
text=prompt_text,
215230
token_ids=tuple(token_ids) if token_ids is not None else None,

tests/integration/test_multi_turn.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def _make_dataset(rows: list[dict]) -> MultiTurnDataset:
7777
def _make_strategy(
7878
ds: MultiTurnDataset,
7979
use_dataset_history: bool = True,
80+
target_concurrency: int | None = None,
8081
) -> MultiTurnStrategy:
8182
mt_cfg = MultiTurnConfig(
8283
turn_timeout_s=10.0,
@@ -86,6 +87,7 @@ def _make_strategy(
8687
conversation_manager=ConversationManager(),
8788
dataset_metadata=ds.conversation_metadata,
8889
multi_turn_config=mt_cfg,
90+
target_concurrency=target_concurrency,
8991
)
9092

9193

@@ -600,6 +602,114 @@ async def test_concurrent_conversations_stress(echo_server):
600602
assert len(responses) == expected_client_turns
601603

602604

605+
@pytest.mark.integration
@pytest.mark.asyncio
async def test_multi_turn_active_conversations_respects_target_concurrency(echo_server):
    """target_concurrency=4 must cap concurrently-active conversations.

    Builds 20 three-turn conversations (user / assistant / user), runs them
    with target_concurrency=4, and samples the size of the strategy's
    active-iterator set at every sample completion; the observed maximum
    must never exceed the cap.
    """
    num_convs = 20
    rows = []
    for i in range(num_convs):
        conv_id = f"cap_conv_{i}"
        # Turn 1: client question, turn 2: dataset-provided assistant reply,
        # turn 3: follow-up client question -> 2 client turns per conversation.
        rows += [
            {
                "conversation_id": conv_id,
                "turn": 1,
                "role": "user",
                "content": f"Q1-{i}",
            },
            {
                "conversation_id": conv_id,
                "turn": 2,
                "role": "assistant",
                "content": f"A1-{i}",
            },
            {
                "conversation_id": conv_id,
                "turn": 3,
                "role": "user",
                "content": f"Q2-{i}",
            },
        ]

    ds = _make_dataset(rows)
    strategy = _make_strategy(ds, target_concurrency=4)
    responses: dict = {}

    # Wrap on_sample_complete so each completion records how many
    # conversation iterators are currently active, then delegates to the
    # original callback. NOTE(review): relies on the private _active_iters
    # attribute being the strategy's concurrency gauge -- confirm.
    observed_max: list[int] = []
    orig_on_sample_complete = strategy.on_sample_complete

    def tracked_on_sample_complete(result) -> None:
        observed_max.append(len(strategy._active_iters))
        orig_on_sample_complete(result)

    strategy.on_sample_complete = tracked_on_sample_complete

    await _run_session(echo_server.url, ds, strategy, responses)

    assert len(responses) == num_convs * 2  # 2 client turns per conversation
    # default=0 guards the degenerate case where no completion was observed.
    assert max(observed_max, default=0) <= 4
650+
651+
652+
@pytest.mark.integration
@pytest.mark.asyncio
async def test_multi_turn_pipeline_exception_propagates(echo_server):
    """An exception raised inside _issue_next_turn must surface from run().

    The second call to the strategy's _issue_next_turn raises; the session
    run must re-raise that error (rather than hang or swallow it) and the
    strategy must finish with no in-flight samples.
    """
    rows = [
        {"conversation_id": "err_c1", "turn": 1, "role": "user", "content": "Q1"},
        {"conversation_id": "err_c1", "turn": 2, "role": "assistant", "content": "A1"},
        {"conversation_id": "err_c1", "turn": 3, "role": "user", "content": "Q2"},
    ]
    ds = _make_dataset(rows)
    strategy = _make_strategy(ds)

    # Fault injection: first call passes through, every later call raises.
    call_count = 0
    orig_issue_next_turn = strategy._issue_next_turn

    def failing_issue_next_turn(*args, **kwargs):
        nonlocal call_count
        call_count += 1
        if call_count >= 2:
            raise RuntimeError("injected pipeline error")
        return orig_issue_next_turn(*args, **kwargs)

    strategy._issue_next_turn = failing_issue_next_turn

    loop = asyncio.get_running_loop()
    http_config = HTTPClientConfig(
        endpoint_urls=[urljoin(echo_server.url, "/v1/chat/completions")],
        warmup_connections=0,
        num_workers=2,
    )
    http_client = await HTTPEndpointClient.create(http_config, loop)
    issuer = HttpClientSampleIssuer(http_client)

    try:
        session = BenchmarkSession(
            issuer=issuer,
            event_publisher=_NoOpPublisher(),
            loop=loop,
            on_sample_complete=strategy.on_sample_complete,
        )
        # Fixed RNG seeds keep scheduling deterministic; bounded duration
        # ensures a hang fails within max_duration_ms rather than forever.
        rt = RuntimeSettings(
            metrics.Throughput(1000),
            [metrics.Throughput(1000)],
            min_duration_ms=0,
            max_duration_ms=30_000,
            n_samples_from_dataset=ds.num_samples(),
            n_samples_to_issue=ds.num_samples(),
            min_sample_count=1,
            rng_sched=random.Random(42),
            rng_sample_index=random.Random(42),
            load_pattern=LoadPattern(type=LoadPatternType.MAX_THROUGHPUT),
        )
        phase = PhaseConfig("perf", rt, ds, PhaseType.PERFORMANCE, strategy=strategy)

        # The injected error must propagate out of session.run, and the
        # wait_for timeout converts a hang into a test failure.
        with pytest.raises(RuntimeError, match="injected pipeline error"):
            await asyncio.wait_for(session.run([phase]), timeout=30.0)

        # After the failure the strategy must not report dangling samples.
        assert strategy._inflight == {}
    finally:
        # Always release client resources, even when an assertion fails.
        await http_client.shutdown_async()
711+
712+
603713
@pytest.mark.integration
604714
@pytest.mark.asyncio
605715
async def test_tools_field_forwarded_to_endpoint(echo_server):
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Unit tests for _precompute_isl_for_multi_turn."""
17+
18+
from unittest.mock import MagicMock, patch
19+
20+
import pytest
21+
from inference_endpoint.commands.benchmark.execute import _precompute_isl_for_multi_turn
22+
23+
24+
def _make_dataloader(samples: list[dict]) -> MagicMock:
25+
dl = MagicMock()
26+
dl.data = samples
27+
return dl
28+
29+
30+
class TestPrecomputeIslForMultiTurn:
    """Behavioral tests for execute._precompute_isl_for_multi_turn.

    AutoTokenizer is patched on the ``transformers`` module itself so the
    helper's function-local ``from transformers import AutoTokenizer``
    resolves to the mock at call time; no real model is downloaded.
    """

    @pytest.mark.unit
    def test_sets_input_tokens_for_samples_with_messages(self):
        """Every sample carrying 'messages' gains a list 'input_tokens'."""
        samples = [
            {"messages": [{"role": "user", "content": "hello"}]},
            {"messages": [{"role": "user", "content": "world"}]},
        ]
        dataloader = _make_dataloader(samples)
        mock_tokenizer = MagicMock()
        # Deterministic fake tokenization: 3 tokens per message.
        mock_tokenizer.apply_chat_template.side_effect = lambda msgs, **_: list(
            range(len(msgs) * 3)
        )

        with patch("transformers.AutoTokenizer") as mock_cls:
            mock_cls.from_pretrained.return_value = mock_tokenizer
            _precompute_isl_for_multi_turn(dataloader, "test-model")

        for sample in samples:
            assert "input_tokens" in sample
            assert isinstance(sample["input_tokens"], list)

    @pytest.mark.unit
    def test_leaves_samples_without_messages_untouched(self):
        """Samples lacking 'messages' are skipped and never tokenized."""
        samples = [
            {"prompt": "no messages here"},
            {"input_tokens": [1, 2, 3]},
        ]
        dataloader = _make_dataloader(samples)
        mock_tokenizer = MagicMock()

        with patch("transformers.AutoTokenizer") as mock_cls:
            mock_cls.from_pretrained.return_value = mock_tokenizer
            _precompute_isl_for_multi_turn(dataloader, "test-model")

        # No tokenization attempted; pre-existing input_tokens preserved.
        mock_tokenizer.apply_chat_template.assert_not_called()
        assert "input_tokens" not in samples[0]
        assert samples[1]["input_tokens"] == [1, 2, 3]

    @pytest.mark.unit
    def test_skips_failed_template_calls_with_warning(self, caplog):
        """A template failure skips that sample and logs one warning."""
        samples = [
            {"messages": [{"role": "user", "content": "good"}]},
            {"messages": [{"role": "user", "content": "bad"}]},
        ]
        dataloader = _make_dataloader(samples)

        # Raise only for the sample whose first message content is "bad".
        def side_effect(msgs, **_):
            if msgs[0]["content"] == "bad":
                raise ValueError("template error")
            return [10, 20, 30]

        mock_tokenizer = MagicMock()
        mock_tokenizer.apply_chat_template.side_effect = side_effect

        with patch("transformers.AutoTokenizer") as mock_cls:
            mock_cls.from_pretrained.return_value = mock_tokenizer
            with caplog.at_level("WARNING"):
                _precompute_isl_for_multi_turn(dataloader, "test-model")

        assert "input_tokens" in samples[0]
        assert "input_tokens" not in samples[1]
        assert "1 turn(s) skipped" in caplog.text

    @pytest.mark.unit
    def test_add_generation_prompt_true(self):
        """Tokenization requests the generation prompt and token output."""
        samples = [{"messages": [{"role": "user", "content": "hi"}]}]
        dataloader = _make_dataloader(samples)
        mock_tokenizer = MagicMock()
        mock_tokenizer.apply_chat_template.return_value = [1, 2, 3]

        with patch("transformers.AutoTokenizer") as mock_cls:
            mock_cls.from_pretrained.return_value = mock_tokenizer
            _precompute_isl_for_multi_turn(dataloader, "test-model")

        # Inspect the keyword arguments of the recorded call.
        _, kwargs = mock_tokenizer.apply_chat_template.call_args
        assert kwargs.get("add_generation_prompt") is True
        assert kwargs.get("tokenize") is True

tests/unit/load_generator/test_async_session.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
PhaseResult,
4040
PhaseType,
4141
SessionResult,
42+
_extract_prompt_text,
4243
)
4344
from inference_endpoint.metrics.metric import Throughput
4445

@@ -882,3 +883,53 @@ def test_perf_results_filter(self, enable_warmup):
882883
assert len(sr.perf_results) == 2
883884
assert len(sr.accuracy_results) == 1
884885
assert sr.perf_results[0].name == "perf1"
886+
887+
888+
@pytest.mark.unit
class TestExtractPromptText:
    """Unit tests for session._extract_prompt_text content handling."""

    def test_string_content_extracted(self):
        """Plain string contents are joined with newlines, in order."""
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi"},
        ]
        assert _extract_prompt_text(messages) == "Hello\nHi"

    def test_multimodal_list_content_text_parts_extracted(self):
        """List-form content yields only its {'type': 'text'} parts."""
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe this image"},
                    {"type": "image_url"},
                ],
            }
        ]
        assert _extract_prompt_text(messages) == "Describe this image"

    def test_mixed_string_and_list_content(self):
        """String and list contents interleave in message order."""
        messages = [
            {"role": "system", "content": "You are helpful"},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What is this?"},
                    {"type": "image_url"},
                ],
            },
        ]
        assert _extract_prompt_text(messages) == "You are helpful\nWhat is this?"

    def test_none_content_skipped(self):
        """Messages whose content is None contribute nothing."""
        messages = [
            {"role": "assistant", "content": None},
            {"role": "user", "content": "Hello"},
        ]
        assert _extract_prompt_text(messages) == "Hello"

    def test_list_content_with_no_text_parts_returns_none(self):
        """All-non-text multimodal content yields None, not an empty string."""
        messages = [{"role": "user", "content": [{"type": "image_url"}]}]
        assert _extract_prompt_text(messages) is None

    def test_non_dict_messages_skipped(self):
        """Non-dict entries in the messages list are ignored."""
        messages = ["not a dict", {"role": "user", "content": "Valid"}]
        assert _extract_prompt_text(messages) == "Valid"

0 commit comments

Comments (0)