Skip to content

Commit 734a146

Browse files
authored
[TRTLLM-12123][feat] Add per-iteration request-aggregate counters to InflightBatchingStats (#13199)
Signed-off-by: Yuewei Na <nv-yna@users.noreply.github.com> Co-authored-by: Yuewei Na <nv-yna@users.noreply.github.com>
1 parent 00218e5 commit 734a146

9 files changed

Lines changed: 813 additions & 4 deletions

File tree

cpp/include/tensorrt_llm/executor/types.h

Lines changed: 29 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -321,6 +321,35 @@ struct InflightBatchingStats
321321
SizeType32 microBatchId;
322322
/// @brief Average number of tokens decoded per request per iteration
323323
float avgNumDecodedTokensPerIter;
324+
/// @brief Context tokens for scheduled context requests that are read from
325+
/// KV cache rather than computed this iteration. Covers prefix-cache hits
326+
/// and previously-chunked tokens for chunked-prefill continuations.
327+
/// Complements @ref numCtxTokens (tokens computed this iteration).
328+
SizeType32 numCtxKvTokens;
329+
/// @brief Total KV context length (prompt + generated-so-far) summed
330+
/// across scheduled generation (decode) requests.
331+
SizeType32 numGenKvTokens;
332+
/// @brief Number of context (prefill) requests waiting in the executor
333+
/// request queue — submitted but not yet scheduled. Excludes non-normal
334+
/// control items (shutdown/cancel) and requests without a payload.
335+
SizeType32 numQueuedContextRequests;
336+
/// @brief Sum of prompt-token counts across queued context requests (the
337+
/// requests counted in @ref numQueuedContextRequests).
338+
SizeType32 numQueuedCtxTokens;
339+
/// @brief Number of generation-only requests waiting in the executor
340+
/// request queue. On a disaggregated-decode engine these are requests
341+
/// that have completed prefill elsewhere and are awaiting KV-cache
342+
/// transfer before they can start decoding. Always 0 on a
343+
/// non-disaggregated or disaggregated-prefill engine.
344+
SizeType32 numQueuedGenRequests;
345+
/// @brief Sum of prompt-token counts across queued generation-only
346+
/// requests (the requests counted in @ref numQueuedGenRequests). Acts
347+
/// as the KV-budget these requests will need once their KV transfer
348+
/// completes.
349+
SizeType32 numQueuedGenKvTokens;
350+
/// @brief Total KV context length summed across paused (preempted-decode)
351+
/// requests. Complements @ref numPausedRequests (count).
352+
SizeType32 numPausedKvTokens;
324353
};
325354

326355
/// @brief Struct that holds speculative decoding stats

cpp/tensorrt_llm/executor/jsonSerialization.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,8 @@ NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(KvCacheStats, maxNumBlocks, freeNumBlocks, us
3030
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(
3131
StaticBatchingStats, numScheduledRequests, numContextRequests, numCtxTokens, numGenTokens, emptyGenSlots);
3232
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(InflightBatchingStats, numScheduledRequests, numContextRequests, numGenRequests,
33-
numPausedRequests, numCtxTokens, microBatchId, avgNumDecodedTokensPerIter);
33+
numPausedRequests, numCtxTokens, microBatchId, avgNumDecodedTokensPerIter, numCtxKvTokens, numGenKvTokens,
34+
numQueuedContextRequests, numQueuedCtxTokens, numQueuedGenRequests, numQueuedGenKvTokens, numPausedKvTokens);
3435
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(SpecDecodingStats, numDraftTokens, numAcceptedTokens, numRequestsWithDraftTokens,
3536
acceptanceLength, iterLatencyMS, draftOverhead);
3637
NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(IterationStats, timestamp, iter, iterLatencyMS, newActiveRequestsQueueLatencyMS,

cpp/tensorrt_llm/executor/serialization.cpp

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1897,8 +1897,16 @@ InflightBatchingStats Serialization::deserializeInflightBatchingStats(std::istre
18971897
auto numCtxTokens = su::deserialize<SizeType32>(is);
18981898
auto microBatchId = su::deserialize<SizeType32>(is);
18991899
auto avgNumDecodedTokensPerIter = su::deserialize<float>(is);
1900+
auto numCtxKvTokens = su::deserialize<SizeType32>(is);
1901+
auto numGenKvTokens = su::deserialize<SizeType32>(is);
1902+
auto numQueuedContextRequests = su::deserialize<SizeType32>(is);
1903+
auto numQueuedCtxTokens = su::deserialize<SizeType32>(is);
1904+
auto numQueuedGenRequests = su::deserialize<SizeType32>(is);
1905+
auto numQueuedGenKvTokens = su::deserialize<SizeType32>(is);
1906+
auto numPausedKvTokens = su::deserialize<SizeType32>(is);
19001907
return InflightBatchingStats{numScheduledRequests, numContextRequests, numGenRequests, numPausedRequests,
1901-
numCtxTokens, microBatchId, avgNumDecodedTokensPerIter};
1908+
numCtxTokens, microBatchId, avgNumDecodedTokensPerIter, numCtxKvTokens, numGenKvTokens,
1909+
numQueuedContextRequests, numQueuedCtxTokens, numQueuedGenRequests, numQueuedGenKvTokens, numPausedKvTokens};
19021910
}
19031911

19041912
void Serialization::serialize(InflightBatchingStats const& inflightBatchingStats, std::ostream& os)
@@ -1910,6 +1918,13 @@ void Serialization::serialize(InflightBatchingStats const& inflightBatchingStats
19101918
su::serialize(inflightBatchingStats.numCtxTokens, os);
19111919
su::serialize(inflightBatchingStats.microBatchId, os);
19121920
su::serialize(inflightBatchingStats.avgNumDecodedTokensPerIter, os);
1921+
su::serialize(inflightBatchingStats.numCtxKvTokens, os);
1922+
su::serialize(inflightBatchingStats.numGenKvTokens, os);
1923+
su::serialize(inflightBatchingStats.numQueuedContextRequests, os);
1924+
su::serialize(inflightBatchingStats.numQueuedCtxTokens, os);
1925+
su::serialize(inflightBatchingStats.numQueuedGenRequests, os);
1926+
su::serialize(inflightBatchingStats.numQueuedGenKvTokens, os);
1927+
su::serialize(inflightBatchingStats.numPausedKvTokens, os);
19131928
}
19141929

19151930
size_t Serialization::serializedSize(InflightBatchingStats const& inflightBatchingStats)
@@ -1922,6 +1937,13 @@ size_t Serialization::serializedSize(InflightBatchingStats const& inflightBatchi
19221937
totalSize += su::serializedSize(inflightBatchingStats.numCtxTokens);
19231938
totalSize += su::serializedSize(inflightBatchingStats.microBatchId);
19241939
totalSize += su::serializedSize(inflightBatchingStats.avgNumDecodedTokensPerIter);
1940+
totalSize += su::serializedSize(inflightBatchingStats.numCtxKvTokens);
1941+
totalSize += su::serializedSize(inflightBatchingStats.numGenKvTokens);
1942+
totalSize += su::serializedSize(inflightBatchingStats.numQueuedContextRequests);
1943+
totalSize += su::serializedSize(inflightBatchingStats.numQueuedCtxTokens);
1944+
totalSize += su::serializedSize(inflightBatchingStats.numQueuedGenRequests);
1945+
totalSize += su::serializedSize(inflightBatchingStats.numQueuedGenKvTokens);
1946+
totalSize += su::serializedSize(inflightBatchingStats.numPausedKvTokens);
19251947
return totalSize;
19261948
}
19271949

cpp/tensorrt_llm/nanobind/executor/bindings.cpp

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,14 @@ void initBindings(nb::module_& m)
131131
.def_rw("num_paused_requests", &tle::InflightBatchingStats::numPausedRequests)
132132
.def_rw("num_ctx_tokens", &tle::InflightBatchingStats::numCtxTokens)
133133
.def_rw("micro_batch_id", &tle::InflightBatchingStats::microBatchId)
134-
.def_rw("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter);
134+
.def_rw("avg_num_decoded_tokens_per_iter", &tle::InflightBatchingStats::avgNumDecodedTokensPerIter)
135+
.def_rw("num_ctx_kv_tokens", &tle::InflightBatchingStats::numCtxKvTokens)
136+
.def_rw("num_gen_kv_tokens", &tle::InflightBatchingStats::numGenKvTokens)
137+
.def_rw("num_queued_context_requests", &tle::InflightBatchingStats::numQueuedContextRequests)
138+
.def_rw("num_queued_ctx_tokens", &tle::InflightBatchingStats::numQueuedCtxTokens)
139+
.def_rw("num_queued_gen_requests", &tle::InflightBatchingStats::numQueuedGenRequests)
140+
.def_rw("num_queued_gen_kv_tokens", &tle::InflightBatchingStats::numQueuedGenKvTokens)
141+
.def_rw("num_paused_kv_tokens", &tle::InflightBatchingStats::numPausedKvTokens);
135142

136143
nb::class_<tle::SpecDecodingStats>(m, "SpecDecodingStats")
137144
.def(nb::init<>())

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 103 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727
FinishReason, InflightBatchingStats,
2828
IterationStats, KvCacheStats,
2929
RequestStage, RequestStats,
30-
SpecDecodingStats,
30+
RequestType, SpecDecodingStats,
3131
StaticBatchingStats)
3232
from tensorrt_llm.bindings.internal.batch_manager import (LlmRequestType,
3333
ReqIdsSet)
@@ -1198,6 +1198,108 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
11981198
# Calculate draft overhead
11991199
stats.specdec_stats.draft_overhead = 0.0 if iter_latency_ms <= 0.0 else float(
12001200
draft_latency_ms) / float(iter_latency_ms)
1201+
1202+
# Extra per-iteration request-aggregate counters attached to
1203+
# inflight_batching_stats. These complement the existing
1204+
# num_context_requests / num_gen_requests / num_ctx_tokens /
1205+
# num_paused_requests members with token-weighted counts and
1206+
# queue/paused KV accounting.
1207+
1208+
# Tokens read from prior state (prefix-cache hits and
1209+
# previously-chunked tokens) summed across scheduled context
1210+
# requests; complements num_ctx_tokens (tokens computed this
1211+
# iteration). Read from py_last_context_chunk, a Python-side
1212+
# cache set by _update_request_states before state mutation — it
1213+
# stays valid after the request transitions to
1214+
# GENERATION_IN_PROGRESS, unlike the C++ getContextChunkSize() /
1215+
# getContextCurrentPosition() accessors that would raise
1216+
# RuntimeError on a mutated request.
1217+
num_ctx_kv_tokens = 0
1218+
for req in scheduled_batch.context_requests:
1219+
if getattr(req, "is_attention_dp_dummy", False):
1220+
continue
1221+
last_chunk = getattr(req, "py_last_context_chunk", None)
1222+
if last_chunk is not None and last_chunk[0] is not None:
1223+
start, _end = last_chunk
1224+
num_ctx_kv_tokens += start
1225+
else:
1226+
try:
1227+
num_ctx_kv_tokens += \
1228+
req.context_current_position
1229+
except RuntimeError:
1230+
pass
1231+
1232+
# Total KV context length (prompt + tokens generated so far)
1233+
# summed across scheduled generation requests.
1234+
num_gen_kv_tokens = 0
1235+
for req in scheduled_batch.generation_requests:
1236+
if getattr(req, "is_attention_dp_dummy", False):
1237+
continue
1238+
try:
1239+
num_gen_kv_tokens += req.get_num_tokens(0)
1240+
except RuntimeError:
1241+
pass
1242+
1243+
# Normal requests waiting in the executor_request_queue that have
1244+
# never been scheduled. Excludes non-normal control items
1245+
# (shutdown/cancel) and items with a missing payload. Each queued
1246+
# item is a RequestQueueItem wrapping an ExecutorRequest
1247+
# (tle::Request). Requests are routed by request_type:
1248+
# - CONTEXT_AND_GENERATION (default) and CONTEXT_ONLY
1249+
# (disagg-prefill side) -> queued-context counters.
1250+
# - GENERATION_ONLY (disagg-decode side, awaiting KV transfer
1251+
# before they can start decoding) -> queued-gen counters.
1252+
# On a non-disagg engine all items land in the context counters;
1253+
# on a disagg-decode engine all items land in the gen counters.
1254+
num_queued_context_requests = 0
1255+
num_queued_ctx_tokens = 0
1256+
num_queued_gen_requests = 0
1257+
num_queued_gen_kv_tokens = 0
1258+
for item in list(self.executor_request_queue.get_request_queue().queue):
1259+
if not item.is_normal_request:
1260+
continue
1261+
if item.request is None:
1262+
continue
1263+
try:
1264+
token_count = len(item.request.input_token_ids)
1265+
except (AttributeError, TypeError) as e:
1266+
# Unusual request shape with no usable token payload;
1267+
# exclude from all queued counters so downstream consumers
1268+
# see consistent per-request averages. Not expected on the
1269+
# current API (ExecutorRequest construction requires a
1270+
# non-empty input_token_ids); logged so that future API drift
1271+
# surfaces instead of being silently dropped.
1272+
logger.warning(f"Excluding queued item {item.id} from queued "
1273+
f"counters: input_token_ids not readable "
1274+
f"({type(e).__name__})")
1275+
continue
1276+
if item.request.request_type == RequestType.REQUEST_TYPE_GENERATION_ONLY:
1277+
num_queued_gen_requests += 1
1278+
num_queued_gen_kv_tokens += token_count
1279+
else:
1280+
num_queued_context_requests += 1
1281+
num_queued_ctx_tokens += token_count
1282+
1283+
# Total KV context length summed across paused (preempted-decode)
1284+
# requests, i.e. requests that were decoding but got evicted back to the waiting
1285+
# pool for this iteration.
1286+
num_paused_kv_tokens = 0
1287+
for req in scheduled_batch.paused_requests:
1288+
if getattr(req, "is_attention_dp_dummy", False):
1289+
continue
1290+
try:
1291+
num_paused_kv_tokens += req.get_num_tokens(0)
1292+
except RuntimeError:
1293+
pass
1294+
1295+
stats.inflight_batching_stats.num_ctx_kv_tokens = num_ctx_kv_tokens
1296+
stats.inflight_batching_stats.num_gen_kv_tokens = num_gen_kv_tokens
1297+
stats.inflight_batching_stats.num_queued_context_requests = num_queued_context_requests
1298+
stats.inflight_batching_stats.num_queued_ctx_tokens = num_queued_ctx_tokens
1299+
stats.inflight_batching_stats.num_queued_gen_requests = num_queued_gen_requests
1300+
stats.inflight_batching_stats.num_queued_gen_kv_tokens = num_queued_gen_kv_tokens
1301+
stats.inflight_batching_stats.num_paused_kv_tokens = num_paused_kv_tokens
1302+
12011303
return stats
12021304

12031305
def _append_iter_stats(self,

tensorrt_llm/executor/base_worker.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -664,7 +664,15 @@ def get_disaggregated_params(self) -> dict:
664664
def _stats_serializer(stats) -> str:
665665
iteration_stats, req_stats = stats[0], stats[1]
666666
kv_iter_stats = stats[2] if len(stats) > 2 else None
667+
667668
stats_dict = json.loads(iteration_stats.to_json_str())
669+
# Tag with dp_rank=0 so Dynamo's adapter can always read
670+
# stat["attentionDpRank"] without a missing-key branch. Attention-DP
671+
# per-rank emission is a follow-up; today FPM only flows under
672+
# non-attention-DP.
673+
# TODO(https://jirasw.nvidia.com/browse/TRTLLM-12123): implement
674+
# per-rank IterationStats delivery under attention-DP.
675+
stats_dict.setdefault("attentionDpRank", 0)
668676

669677
if req_stats is not None and len(req_stats) > 0:
670678
stats_dict["requestStats"] = []

tests/unittest/_torch/executor/test_pytorch_model_engine.py

Lines changed: 90 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -191,6 +191,96 @@ def test_pad_generation_requests(self) -> None:
191191

192192
kv_cache_manager.shutdown()
193193

194+
def test_pad_batch_strips_cudagraph_dummies_on_clean_exit(self) -> None:
195+
# Regression guard for the invariant that CUDAGraphRunner.pad_batch's
196+
# `finally` strips every is_cuda_graph_dummy=True entry from
197+
# scheduled_requests.generation_requests before the `with` block
198+
# exits. Downstream consumers of scheduled_batch.generation_requests
199+
# — including the per-iteration stats populate block in
200+
# PyExecutor._update_iter_stats — rely on never observing
201+
# cudagraph dummies.
202+
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
203+
resource_manager = ResourceManager(
204+
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
205+
206+
# batch_size=5 rounds up to 8 (nearest captured graph size in the
207+
# fixture config) -> padding_size=3, deterministically.
208+
real_batch_size = 5
209+
max_seq_len = 1
210+
real_requests = [
211+
_create_request(max_seq_len, i) for i in range(real_batch_size)
212+
]
213+
real_ids = [id(r) for r in real_requests]
214+
215+
batch = ScheduledRequests()
216+
batch.generation_requests = list(real_requests)
217+
218+
with model_engine.cuda_graph_runner.pad_batch(
219+
batch, resource_manager) as padded_batch:
220+
# Positive assertion that padding actually fired — guards
221+
# against a vacuous pass where padding was a no-op.
222+
self.assertGreater(
223+
len(padded_batch.generation_requests), real_batch_size,
224+
"padding did not fire; fixture config may have drifted "
225+
"so that 5 no longer rounds up to 8")
226+
# Every appended entry past the original count is a
227+
# cudagraph-flagged dummy.
228+
for req in padded_batch.generation_requests[real_batch_size:]:
229+
self.assertTrue(
230+
getattr(req, "is_cuda_graph_dummy", False),
231+
"pad_batch appended a request without "
232+
"is_cuda_graph_dummy=True")
233+
# Real requests' identities and order are untouched.
234+
self.assertEqual([
235+
id(r)
236+
for r in padded_batch.generation_requests[:real_batch_size]
237+
], real_ids)
238+
239+
# After the with-block: finally must have sliced off the padding.
240+
self.assertEqual(
241+
len(batch.generation_requests), real_batch_size,
242+
"pad_batch.finally did not strip cudagraph dummies — "
243+
"downstream consumers of scheduled_batch.generation_requests "
244+
"would observe the leaked dummies")
245+
for req in batch.generation_requests:
246+
self.assertFalse(
247+
getattr(req, "is_cuda_graph_dummy", False),
248+
"cudagraph dummy leaked out of pad_batch's finally")
249+
250+
kv_cache_manager.shutdown()
251+
252+
def test_pad_batch_strips_cudagraph_dummies_on_exception(self) -> None:
253+
# The strip must fire even when the body raises. This is the
254+
# critical property of `finally` vs. a plain trailing statement —
255+
# it guards the invariant on the error path. A refactor that
256+
# accidentally dropped the `finally` would be caught here but not
257+
# by the clean-exit variant.
258+
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
259+
resource_manager = ResourceManager(
260+
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
261+
262+
real_batch_size = 5
263+
real_requests = [_create_request(1, i) for i in range(real_batch_size)]
264+
265+
batch = ScheduledRequests()
266+
batch.generation_requests = list(real_requests)
267+
268+
class _ForwardBoom(Exception):
269+
pass
270+
271+
with self.assertRaises(_ForwardBoom):
272+
with model_engine.cuda_graph_runner.pad_batch(
273+
batch, resource_manager) as padded_batch:
274+
self.assertGreater(len(padded_batch.generation_requests),
275+
real_batch_size)
276+
raise _ForwardBoom()
277+
278+
self.assertEqual(len(batch.generation_requests), real_batch_size)
279+
for req in batch.generation_requests:
280+
self.assertFalse(getattr(req, "is_cuda_graph_dummy", False))
281+
282+
kv_cache_manager.shutdown()
283+
194284
def test_position_id_preparation(self):
195285
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
196286
resource_manager = ResourceManager(

0 commit comments

Comments
 (0)