Skip to content

Commit ebe59b2

Browse files
committed
Add usage.prompt_tokens_details.cached_tokens for prefix caching
1 parent 70d2682 commit ebe59b2

15 files changed

Lines changed: 240 additions & 31 deletions

File tree

benchmark/benchmark_chat_completion.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
33
This script focuses on eval-style JSONL dumps where each row contains OpenAI
44
chat ``messages``, or a string/list ``prompt`` (e.g. dapo-math-17k). List-type
5-
``prompt`` values are treated as message lists. It records streaming latency traces,
5+
``prompt`` values are treated as message lists. Optional per-row ``tools`` and
6+
``tool_choice`` fields are forwarded to ``/v1/chat/completions``. It records streaming latency traces,
67
aggregates TTFT/ITL/TPOT metrics, and writes table plus report artifacts for concurrency/RPS sweeps.
78
89
Generation options include ``--output-tokens`` (``max_completion_tokens``),
@@ -42,6 +43,8 @@ class BenchmarkRequest:
4243
messages: list[dict[str, Any]] = field(default_factory=list)
4344
input_ids: list[int] | None = None
4445
image_data: Any = None
46+
tools: list[dict[str, Any]] | None = None
47+
tool_choice: Any | None = None
4548

4649

4750
@dataclass
@@ -73,6 +76,7 @@ class RequestTrace:
7376
chunk_times: list[float] = field(default_factory=list)
7477
prompt_tokens: int = 0
7578
completion_tokens: int = 0
79+
cached_tokens: int = 0
7680
usage_available: bool = False
7781
generated_text: str = ''
7882
reasoning_text: str = ''
@@ -205,6 +209,21 @@ def _extract_messages(row: dict[str, Any]) -> list[dict[str, Any]]:
205209
raise ValueError('row must contain messages or prompt')
206210

207211

212+
def _extract_tools(row: dict[str, Any]) -> list[dict[str, Any]] | None:
213+
tools = row.get('tools')
214+
if not tools:
215+
return None
216+
if not isinstance(tools, list):
217+
raise ValueError('tools must be a list when present')
218+
return tools
219+
220+
221+
def _extract_tool_choice(row: dict[str, Any]) -> Any | None:
222+
if 'tool_choice' not in row:
223+
return None
224+
return row['tool_choice']
225+
226+
208227
def _normalize_row(
209228
row: dict[str, Any],
210229
dataset: str,
@@ -213,23 +232,35 @@ def _normalize_row(
213232
) -> BenchmarkRequest:
214233
request_id = str(row.get('id', f'{dataset}-{row_index}'))
215234
messages = _extract_messages(row)
235+
tools = _extract_tools(row)
236+
tool_choice = _extract_tool_choice(row)
216237

217238
if tokenizer is not None:
218-
prompt_str = tokenizer.apply_chat_template(
219-
messages,
220-
tokenize=False,
221-
add_generation_prompt=True,
222-
)
239+
template_kwargs: dict[str, Any] = {
240+
'tokenize': False,
241+
'add_generation_prompt': True,
242+
}
243+
if tools is not None:
244+
template_kwargs['tools'] = tools
245+
prompt_str = tokenizer.apply_chat_template(messages, **template_kwargs)
223246
return BenchmarkRequest(
224247
dataset=dataset,
225248
id=request_id,
226249
input_ids=tokenizer.encode(prompt_str, add_special_tokens=False),
227250
image_data=row.get('image_data'),
251+
tools=tools,
252+
tool_choice=tool_choice,
228253
)
229254

230255
if not messages:
231256
raise ValueError(f'row {row_index} in {dataset} has invalid messages')
232-
return BenchmarkRequest(dataset=dataset, id=request_id, messages=messages)
257+
return BenchmarkRequest(
258+
dataset=dataset,
259+
id=request_id,
260+
messages=messages,
261+
tools=tools,
262+
tool_choice=tool_choice,
263+
)
233264

234265

235266
def _read_raw_rows(
@@ -326,6 +357,13 @@ def parse_sse_line(line: bytes | str) -> SSEEvent:
326357
)
327358

328359

360+
def _cached_tokens_from_usage(usage: dict[str, Any] | None) -> int:
361+
if not usage:
362+
return 0
363+
details = usage.get('prompt_tokens_details') or {}
364+
return int(details.get('cached_tokens', 0) or 0)
365+
366+
329367
def build_payload(
330368
request: BenchmarkRequest,
331369
model: str,
@@ -372,6 +410,10 @@ def build_payload(
372410
payload['logprobs'] = True
373411
if top_logprobs is not None:
374412
payload['top_logprobs'] = top_logprobs
413+
if request.tools:
414+
payload['tools'] = request.tools
415+
if request.tool_choice is not None:
416+
payload['tool_choice'] = request.tool_choice
375417
if extra_body:
376418
payload.update(extra_body)
377419
return payload
@@ -482,6 +524,7 @@ async def request_chat_completion(
482524
trace.completion_tokens = int(
483525
event.usage.get('completion_tokens', trace.completion_tokens) or 0
484526
)
527+
trace.cached_tokens = _cached_tokens_from_usage(event.usage)
485528
if event.routed_experts and shared_store is not None:
486529
try:
487530
await fetch_routed_experts(shared_store, event.routed_experts)
@@ -632,6 +675,7 @@ def aggregate_traces(traces: Sequence[RequestTrace]) -> list[dict[str, Any]]:
632675
duration = max(end - start, 0.0)
633676
total_input = sum(trace.prompt_tokens for trace in completed)
634677
total_output = sum(trace.completion_tokens for trace in completed)
678+
total_cached = sum(trace.cached_tokens for trace in completed)
635679
itls = [itl for trace in completed for itl in trace.itls_s]
636680

637681
summary: dict[str, Any] = {
@@ -646,6 +690,8 @@ def aggregate_traces(traces: Sequence[RequestTrace]) -> list[dict[str, Any]]:
646690
'duration_s': duration,
647691
'total_input_tokens': total_input,
648692
'total_output_tokens': total_output,
693+
'total_cached_tokens': total_cached,
694+
'cache_hit_rate': total_cached / total_input if total_input > 0 else 0.0,
649695
'request_throughput_req_s': len(completed) / duration if duration > 0 else 0.0,
650696
'input_throughput_tok_s': total_input / duration if duration > 0 else 0.0,
651697
'output_throughput_tok_s': total_output / duration if duration > 0 else 0.0,
@@ -691,6 +737,7 @@ def _write_requests_csv(path: Path, rows: Sequence[dict[str, Any]]) -> None:
691737
'e2e_latency_s',
692738
'prompt_tokens',
693739
'completion_tokens',
740+
'cached_tokens',
694741
'usage_available',
695742
'finish_reason',
696743
'error',

lmdeploy/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ class Response:
536536
last_hidden_state: torch.Tensor = None
537537
index: int = 0
538538
routed_experts: Any = None
539+
cached_tokens: int = 0
539540

540541
def __str__(self):
541542
return f'text={self.text}\n{self._format_none_text_fields()}'
@@ -651,6 +652,7 @@ class RequestMetrics:
651652
token_timestamp: float = 0.0
652653
engine_events: list[EngineEvent] = field(default_factory=list)
653654
spec_info: dict[str, Any] | None = None
655+
cached_tokens: int = 0
654656

655657

656658
@dataclass
@@ -674,6 +676,7 @@ class EngineOutput:
674676
cache_block_ids: list[int] | None = None
675677
req_metrics: RequestMetrics | None = None
676678
routed_experts: torch.Tensor = None
679+
cached_tokens: int = 0
677680

678681

679682
@dataclass

lmdeploy/metrics/loggers.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,26 @@ def __init__(self, model_name: str, max_model_len: int, dp_rank: int = 0):
271271
buckets=build_1_2_5_buckets(max_model_len),
272272
labelnames=labelnames).labels(*labelvalues)
273273

274+
self.histogram_num_cached_tokens_request = \
275+
prometheus_client.Histogram(
276+
name='lmdeploy:request_cached_tokens',
277+
documentation='Number of prefix-cached input tokens per request.',
278+
buckets=build_1_2_5_buckets(max_model_len),
279+
labelnames=labelnames).labels(*labelvalues)
280+
281+
self.histogram_cache_hit_ratio_request = \
282+
prometheus_client.Histogram(
283+
name='lmdeploy:request_cache_hit_ratio',
284+
documentation='Prefix cache hit ratio (cached_tokens / prompt_tokens) per request.',
285+
buckets=[0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
286+
labelnames=labelnames).labels(*labelvalues)
287+
288+
self.counter_cached_tokens_total = \
289+
prometheus_client.Counter(
290+
name='lmdeploy:cached_tokens_total',
291+
documentation='Total prefix-cached input tokens served.',
292+
labelnames=labelnames).labels(*labelvalues)
293+
274294
self.histogram_iteration_tokens = \
275295
prometheus_client.Histogram(
276296
name='lmdeploy:iteration_tokens_total',
@@ -385,6 +405,10 @@ def record_finish(self, stats: RequestStats) -> None:
385405
self.histogram_decode_time_request.observe(stats.decode_time_interval)
386406
self.histogram_num_prompt_tokens_request.observe(stats.prompt_tokens)
387407
self.histogram_num_generation_tokens_request.observe(stats.generation_tokens)
408+
self.histogram_num_cached_tokens_request.observe(stats.cached_tokens)
409+
if stats.prompt_tokens > 0:
410+
self.histogram_cache_hit_ratio_request.observe(stats.cached_tokens / stats.prompt_tokens)
411+
self.counter_cached_tokens_total.inc(stats.cached_tokens)
388412

389413
@staticmethod
390414
def _get_counter_value(counter) -> float:

lmdeploy/metrics/metrics_processor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ async def _run_metrics_handler(self):
5151
outputs, req_stats, iteration_stats, specdecode_stats = update_data
5252

5353
# update request stats
54+
if outputs:
55+
req_stats.cached_tokens = outputs.cached_tokens
5456
if outputs and outputs.req_metrics:
5557
# when users visit "/abort_request" endpoint, `req_metrics` might be None
5658
req_stats.update_from_events(outputs.req_metrics.engine_events)

lmdeploy/metrics/stats.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ def __init__(self, arrival_time: float = None, prompt_tokens: int = 0):
124124
self.prompt_tokens = prompt_tokens
125125

126126
self.generation_tokens: int = 0
127+
self.cached_tokens: int = 0
127128
self.queued_time: float = 0.0
128129
self.scheduled_time: float = 0.0
129130
self.first_token_time: float = 0.0

lmdeploy/pytorch/engine/engine_instance.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ async def async_stream_infer(self,
224224

225225
cache_block_ids = resp.data.get('cache_block_ids', None) if resp.data else None
226226
req_metrics = resp.data.get('req_metrics', None) if resp.data else None
227+
cached_tokens = req_metrics.cached_tokens if req_metrics is not None else 0
227228
logprobs = resp.data.pop('logprobs', None) if resp.data else None
228229

229230
if resp.type == ResponseType.SUCCESS:
@@ -234,6 +235,7 @@ async def async_stream_infer(self,
234235
token_ids[output_offset:].tolist(),
235236
cache_block_ids=cache_block_ids,
236237
req_metrics=req_metrics,
238+
cached_tokens=cached_tokens,
237239
logprobs=logprobs)
238240
output_offset = len(token_ids)
239241
elif resp.type in (ResponseType.FINISH, ResponseType.CANCEL):
@@ -258,6 +260,7 @@ async def async_stream_infer(self,
258260
logits=logits,
259261
cache_block_ids=cache_block_ids,
260262
req_metrics=req_metrics,
263+
cached_tokens=cached_tokens,
261264
routed_experts=routed_experts,
262265
logprobs=logprobs)
263266
break

lmdeploy/pytorch/engine/engine_loop.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,10 @@ def __get_logprobs(batched_outputs: 'BatchedOutputs'):
351351
if num_draft_tokens is not None and model_inputs is None and self.config.enable_metrics:
352352
num_accepted_tokens = (batched_outputs.next_token_ids[idx] > -1).sum() - 1
353353
spec_info = dict(num_draft_tokens=num_draft_tokens, num_accepted_tokens=num_accepted_tokens.item())
354-
req_metrics = RequestMetrics(new_token_timestamp, msg.engine_events, spec_info=spec_info)
354+
req_metrics = RequestMetrics(new_token_timestamp,
355+
msg.engine_events,
356+
spec_info=spec_info,
357+
cached_tokens=msg.prefix_cache_hit_tokens)
355358
out = InferOutput(session_id=session_id,
356359
resp=msg.resp,
357360
finish=finish,

lmdeploy/pytorch/messages.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ class SchedulerSequence:
657657
# mrope
658658
history_mrope_pos_ids: HistoryMropePosIds = field(default_factory=HistoryMropePosIds)
659659

660+
# prefix caching
661+
prefix_cache_hit_tokens: int = 0
662+
660663
def __post_init__(self):
661664
"""Post init."""
662665
self._seq_meta: SequenceMeta = self.session.seq_meta

lmdeploy/pytorch/paging/block_trie.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def get_root(self, adapter_name: str):
8383
def match(self, seq: SchedulerSequence):
8484
"""Match sequence and cache."""
8585
if not self.enable:
86+
seq.prefix_cache_hit_tokens = 0
8687
return
8788

8889
block_size = self.block_size
@@ -124,6 +125,7 @@ def __match_success(node: Node):
124125
# record prefix hit
125126
self.stats.num_query_tokens += seq.num_all_ids - init_num_matched
126127
self.stats.num_hit_tokens += num_matched - init_num_matched
128+
seq.prefix_cache_hit_tokens = num_matched - init_num_matched
127129

128130
seq.logical_blocks.last_shared_node = curr
129131

lmdeploy/pytorch/paging/scheduler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def _reorder_migrating():
159159
max_batches = self.scheduler_config.max_batches - self.num_ready() - self.num_running()
160160
while len(migration_waiting) > 0 and len(migration_ready) < max_batches:
161161
seq = migration_waiting.pop(0)
162-
self.block_trie.match(migration_waiting)
162+
self.block_trie.match(seq)
163163
if not __evict_for_seq(seq, migration_waiting):
164164
break
165165

0 commit comments

Comments
 (0)