Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
896bec6
[TRTLLM-12622][feat] Add native post-processing hook to trtllm-serve
xwang233 Jun 10, 2026
5a7dbba
[TRTLLM-12622][test] Add e2e tests for trtllm-serve post-processing hook
xwang233 Jun 10, 2026
11c5f23
[TRTLLM-12622][fix] Drop late proxy responses after post-processor ho…
xwang233 Jun 10, 2026
1146de6
[TRTLLM-12622][doc] Document the trtllm-serve post-processing hook
xwang233 Jun 10, 2026
b46c4e7
[TRTLLM-12622][chore] Address post-processor hook review feedback
xwang233 Jun 11, 2026
959f2da
[TRTLLM-12622][chore] Fail-fast hook registration + review fixes
xwang233 Jun 15, 2026
278da1f
[TRTLLM-12622][refactor] Scope post-processing hook to the LLM instance
xwang233 Jun 15, 2026
696fa4e
Merge remote-tracking branch 'origin/main' into trtllm-12622-post-pro…
xwang233 Jun 15, 2026
4068a2d
[TRTLLM-12622][chore] Trim verbose comments and a redundant test
xwang233 Jun 15, 2026
e20bbd1
[TRTLLM-12622][fix] Run post-processing hook regardless of detokenize…
xwang233 Jun 15, 2026
3699efe
[TRTLLM-12622][fix] Withhold all client channels on suppress/terminate
xwang233 Jun 15, 2026
4582a2d
[TRTLLM-12622][doc] Clarify per-chunk withholding semantics
xwang233 Jun 15, 2026
aa3a324
[TRTLLM-12622][fix] Fail closed on hook errors; enum verdict actions
xwang233 Jun 15, 2026
f66eb4d
[TRTLLM-12622][test] Register post-processor hook unit tests in L0 A10
xwang233 Jun 16, 2026
4dbe9f0
Merge remote-tracking branch 'origin/main' into trtllm-12622-post-pro…
xwang233 Jun 22, 2026
689c8db
[TRTLLM-12622][fix] Validate hook is callable at load time
xwang233 Jun 22, 2026
93793de
[TRTLLM-12622][chore] Rename post-processor hook public types, de-dup…
xwang233 Jun 23, 2026
d4c096f
Merge remote-tracking branch 'origin/main' into trtllm-12622-post-pro…
xwang233 Jun 23, 2026
5fd660e
[TRTLLM-12622][chore] Rename --post_processor flag/field to --post_pr…
xwang233 Jun 23, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions docs/source/features/post-processor-hook.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
(post-processor-hook)=

# Post-Processing Hook

`trtllm-serve` supports a user-supplied **post-processing hook**: a native, per-request seam that runs
on each generated output *after* detokenization and *before* the per-endpoint response formatter. It
lets a deployment rewrite, redact, suppress, or terminate model output — including stateful logic that
spans the chunks of a streamed response — without modifying TensorRT LLM source.

The hook is a plain Python callable class supplied by import path, mirroring `--custom_tokenizer`. It
is owned by the `LLM` instance (and built once in each post-processing worker process when those are
enabled) and invoked once per output, per streaming chunk (plus a final call), so it can hold its own
per-request state. Independent `LLM` instances in one process each own a separate hook instance.

```{note}
This feature is a prototype and its interface may change in a future release.
```

For the interface definitions referenced below, see
[`tensorrt_llm/executor/postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py).

## Enabling the hook

Pass the dotted import path of your hook class to `--post_processor_hook`:

```bash
trtllm-serve <model> --post_processor_hook my_pkg.guardrail.MyPostProcessorHook
```

Equivalently, set it in a YAML config passed via `--extra_llm_api_options`:

```yaml
post_processor_hook: my_pkg.guardrail.MyPostProcessorHook
```

The class must be:

- **Importable** — installed (`pip install`) or on `PYTHONPATH` when the server (and its
post-processing worker processes, if any) start.
- **Picklable and no-argument-constructible** — the hook is reconstructed by reference inside each
process; `__init__` takes no required arguments and is the one-time setup point.

The hook works the same with or without the out-of-process post-processing worker pool
(`--num_postprocess_workers`).

## The hook interface

A hook implements a single method, `__call__(self, chunk) -> verdict`:

```python
from tensorrt_llm.executor.postprocessor_hook import (
PostProcessorHookChunk,
PostProcessorHookVerdict,
emit,
suppress,
terminate,
)


class MyPostProcessorHook:
def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict:
return emit(chunk.text_diff) # pass through unchanged
```

### `PostProcessorHookChunk`

The payload handed to the hook for one output chunk — `request_id`, `output_index`, `text_diff`,
`text`, `token_ids_diff`, `is_final`, `aborted`, and `streaming`. See the `PostProcessorHookChunk`
dataclass in [`postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py)
for the authoritative field-by-field descriptions.

### Verdicts

Return one of the following helpers from `__call__`:

| Helper | Effect |
|--------|--------|
| `emit(text)` | Emit `text` for this chunk. Use it to pass output through unchanged (`emit(chunk.text_diff)`) or to rewrite/redact it. Affects the **text** channel only — it does not synthesize matching `token_ids` (see *Text vs. token ids* below). |
| `suppress()` | Withhold this chunk entirely across **all** client-visible channels — text, `token_ids`, and `logprobs` (so `detokenize=false` token output is withheld too). |
| `terminate(reason)` | Stop the stream for this request, withholding the terminating chunk on all channels. `reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. |

Verdicts are **per chunk**: `suppress()` withholds the current chunk, and `terminate()` stops generation while keeping the chunks already emitted before it. This is consistent across streaming and non-streaming — a non-streaming response contains exactly the content the hook emitted before the first `suppress`/`terminate`, i.e. the same content a streaming client would have received. A hook that must withhold the *entire* output (all-or-nothing) should `suppress()` from the first chunk (it sees every chunk) rather than emitting and then terminating.

## Usage examples

### Rewrite output

A stateless hook that upper-cases every chunk:

```python
from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, emit


class UpperCaseHook:
def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict:
return emit(chunk.text_diff.upper())
```

### Stateful guardrail that terminates on a banned phrase

This hook accumulates text per request (keyed by `request_id`), stops the stream as soon as a banned
phrase appears, and releases its state when the request finishes:

```python
from tensorrt_llm.executor.postprocessor_hook import (
PostProcessorHookChunk, PostProcessorHookVerdict, emit, terminate,
)


class BannedPhraseGuard:
_BANNED = ("forbidden phrase",)

def __init__(self):
# Per-request accumulators owned entirely by the hook.
self._buffers: dict[int, str] = {}

def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict:
buffer = self._buffers.get(chunk.request_id, "") + chunk.text_diff.lower()
self._buffers[chunk.request_id] = buffer

if any(phrase in buffer for phrase in self._BANNED):
self._buffers.pop(chunk.request_id, None)
return terminate("banned_phrase")

if chunk.is_final:
self._buffers.pop(chunk.request_id, None)
return emit(chunk.text_diff)
```

### Suppress output

A hook that withholds all client-visible text:

```python
from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, suppress


class SuppressHook:
def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict:
return suppress()
```

## Per-request state

The hook instance is owned by the `LLM` (built once in each post-processing worker process when the
pool is enabled) and shared across all requests it handles, so any per-request state must be keyed by
`chunk.request_id` and released when `chunk.is_final` is seen (or after a `terminate`). State is not
shared across processes or across separate `LLM` instances; when the post-processing worker pool is
enabled, all chunks of a single request are still routed to the same worker, so per-request state
remains consistent for that request.

Engine-level batching is transparent to the hook: even when requests are batched together in the
engine, the hook is invoked **once per request** (per output, per chunk) — there is no batched-call
form — so keying state on `chunk.request_id` is sufficient to keep concurrent requests isolated.

## Supported endpoints and limitations

- **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. The hook is
expected to apply to the `responses` endpoint for non-harmony models as well (shared detokenization
path), though that endpoint is not covered by the current end-to-end tests.
- **Not client-bypassable**: the hook is a server-side guardrail, so it runs on every response even
when a `completions` request sets `detokenize=false`. The server detokenizes for the hook regardless;
the `detokenize` flag still controls only the returned channel (text vs. token ids). A
`suppress`/`terminate` verdict withholds **all** client-visible channels — text, `token_ids`, and
`logprobs` — on both the streaming and non-streaming paths, so a client cannot recover withheld
content through any channel.
- **Requires a tokenizer**: the hook needs detokenized text to inspect, so `--post_processor_hook` combined
with `skip_tokenizer_init` is rejected at startup rather than silently disabled.
- **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from
raw token ids, it would bypass the text-based hook, so the server fails fast at startup when
`--post_processor_hook` is combined with a harmony model.
- **Disaggregated serving**: the context and generation servers are separate processes, each running
the hook on its own phase under a different `request_id`; per-request state cannot be correlated
across the two. A `terminate` on one phase does not propagate to the other.
- **Text vs. token ids**: `emit` rewrites the **text** channel only — it does not rewrite the underlying
`token_ids`/`logprobs`, so a client reading both should expect them to diverge. (`suppress`/`terminate`
withhold all channels, so they do not diverge.)
- **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that
rewrites or suppresses text may desync those parsers; prefer `terminate`, or apply such hooks to
plain-text requests.
- **Hook errors fail closed**: if the hook raises, the request fails with an error rather than serving
the un-vetted chunk (the server and other requests stay alive). A returned verdict with an unknown
action is rejected the same way.
- **`n` > 1 / beam search**: `emit` and `suppress` act per output sequence, but `terminate` cancels the
**whole** request (all sequences), because the engine request is the unit of cancellation — a
`terminate` on one candidate ends the others too. Hooks needing per-sequence state should key on
`(request_id, output_index)`.

## Tests

The hook's unit and end-to-end tests double as runnable usage examples:

```bash
# Unit tests for the hook contract (rewrite / suppress / terminate, per-request state, loader)
pytest tests/unittest/executor/test_postprocessor_hook.py -v
```

- End-to-end serving tests across endpoints, streaming modes, and worker-pool settings:
`tests/unittest/llmapi/apps/_test_openai_post_processor.py`.
- The deterministic sample hooks used by those tests:
`tests/unittest/llmapi/apps/_postproc_hook_samples.py`.
1 change: 1 addition & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Welcome to TensorRT LLM's Documentation!
features/quantization.md
features/sampling.md
features/additional-outputs.md
features/post-processor-hook.md
features/guided-decoding.md
features/speculative-decoding.md
features/checkpoint-loading.md
Expand Down
28 changes: 23 additions & 5 deletions tensorrt_llm/commands/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def is_non_default_or_required(param_name, value, backend, explicit_cli_keys):
"""
always_include = {
"model", "backend", "tokenizer", "custom_tokenizer",
"postprocess_tokenizer_dir"
"post_processor_hook", "postprocess_tokenizer_dir"
}

if param_name in always_include:
Expand Down Expand Up @@ -181,6 +181,7 @@ def get_llm_args(
model: str,
tokenizer: Optional[str] = None,
custom_tokenizer: Optional[str] = None,
post_processor_hook: Optional[str] = None,
backend: str = "pytorch",
max_beam_width: int = BuildConfig.model_fields["max_beam_width"].
default,
Expand Down Expand Up @@ -238,6 +239,8 @@ def get_llm_args(
tokenizer,
"custom_tokenizer":
custom_tokenizer,
"post_processor_hook":
post_processor_hook,
"postprocess_tokenizer_dir":
tokenizer or model,
"kv_cache_config":
Expand Down Expand Up @@ -634,6 +637,19 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
"Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path "
"(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').",
"prototype"))
@click.option(
"--post_processor_hook",
type=str,
default=None,
help=help_info_with_stability_tag(
"Python import path of a user post-processing hook applied after "
"detokenization and before the per-endpoint response formatter (e.g. "
"'my_pkg.guardrail.MyPostProcessorHook'). The class must be importable "
"and picklable, take no constructor arguments, and be callable as "
"'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). "
"It runs once per output, per streaming chunk, and may rewrite, "
"suppress, or terminate the output; it owns its own per-request state.",
"prototype"))
@click.option("--host",
type=str,
default="localhost",
Expand Down Expand Up @@ -912,10 +928,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"],
"Types of agents to schedule. Now Only Support Open Deep Research agent.")
def serve(
model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str],
host: str, port: int, log_level: str, backend: str, max_beam_width: int,
max_batch_size: int, max_num_tokens: int, max_seq_len: int,
tensor_parallel_size: int, pipeline_parallel_size: int,
context_parallel_size: int, moe_expert_parallel_size: Optional[int],
post_processor_hook: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
max_num_tokens: int, max_seq_len: int, tensor_parallel_size: int,
pipeline_parallel_size: int, context_parallel_size: int,
moe_expert_parallel_size: Optional[int],
moe_cluster_parallel_size: Optional[int], gpus_per_node: Optional[int],
free_gpu_memory_fraction: float, kv_cache_dtype: str,
num_postprocess_workers: int, trust_remote_code: bool,
Expand Down Expand Up @@ -996,6 +1013,7 @@ def _serve_llm():
model=model,
tokenizer=tokenizer,
custom_tokenizer=custom_tokenizer,
post_processor_hook=post_processor_hook,
backend=backend,
max_beam_width=max_beam_width,
max_batch_size=max_batch_size,
Expand Down
32 changes: 29 additions & 3 deletions tensorrt_llm/executor/postproc_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from ..logger import logger
from ..sampling_params import SamplingParams
from .ipc import ZeroMqQueue
from .postprocessor_hook import load_post_processor_hook
from .utils import ErrorResponse, is_llm_response

if TYPE_CHECKING:
Expand Down Expand Up @@ -46,6 +47,10 @@ class PostprocWorkerConfig:
''' The config for the postprocess worker. '''
num_postprocess_workers: int = 0
postprocess_tokenizer_dir: Optional[str] = None
# Dotted import path of the user post-processing hook, or
# None. NOTE: distinct from ``PostprocParams.post_processor``, which is the
# per-endpoint response *formatter* (a Callable), not this hook.
post_processor_hook: Optional[str] = None

@property
def enabled(self) -> bool:
Expand Down Expand Up @@ -85,6 +90,7 @@ def __init__(
tokenizer_dir: str,
record_creator: Callable[
["PostprocWorker.Input", TransformersTokenizer], Any],
post_processor_hook: Optional[str] = None,
):
'''
Args:
Expand All @@ -93,6 +99,7 @@ def __init__(
tokenizer_dir (str): The directory to load tokenizer.
record_creator (Callable[["ResponsePostprocessWorker.Input"], Any]): A creator for creating a record for a request.
result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback handles the final result.
post_processor_hook (Optional[str]): Import path of the user post-processing hook; built once and threaded onto each record.
'''

self._records: Dict[int, GenerationResult] = {}
Expand All @@ -113,6 +120,12 @@ def __init__(
# Load the tokenizer and share in all records
self._tokenizer = load_hf_tokenizer(tokenizer_dir)

# Build the user post-processing hook once, like the
# tokenizer above; threaded onto each record in ``_handle_input``.
self._post_processor_hook = (
load_post_processor_hook(post_processor_hook)
if post_processor_hook else None)

@staticmethod
def default_record_creator(
inp: "PostprocWorker.Input", tokenizer: TransformersTokenizer
Expand Down Expand Up @@ -144,6 +157,10 @@ async def _handle_input(
# TODO: support variant creation later
self._records[req_id] = self._record_creator(
input, self._tokenizer)
# Thread the hook onto the record here rather than
# via record_creator, so custom record_creators keep working.
self._records[
req_id]._post_processor_hook = self._post_processor_hook
if input.disaggregated_params is not None:
self._records[
req_id]._disaggregated_params = input.disaggregated_params
Expand Down Expand Up @@ -202,6 +219,11 @@ async def handle_single_input(inp: PostprocWorker.Input,
res, metrics, perf_metrics, disaggregated_params = await self._handle_input(
inp)
record = self._records.get(client_id)
# A `terminate` verdict forces the record done;
# honor it so the stream stops and the record is popped without
# waiting for the engine's own is_final.
if record is not None and record._done:
is_final = True
should_abort = record._aborted if record else False
finish_reason = record.outputs[0].finish_reason if (
record and record.outputs
Expand All @@ -221,7 +243,7 @@ async def handle_single_input(inp: PostprocWorker.Input,
num_generated_tokens=num_generated_tokens,
))
if is_final:
self._records.pop(client_id)
self._records.pop(client_id, None)
except Exception as e:
logger.error(
f"Postprocessing error for client {client_id}: {e}\n"
Expand Down Expand Up @@ -268,9 +290,13 @@ async def main():
@print_traceback_on_error
def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]],
feedout_ipc_addr: tuple[str, Optional[bytes]],
tokenizer_dir: str, record_creator: Callable):
tokenizer_dir: str,
record_creator: Callable,
post_processor_hook: Optional[str] = None):
# Pass the hook import path; PostprocWorker builds it once.
worker = PostprocWorker(feedin_ipc_addr,
feedout_ipc_addr,
tokenizer_dir=tokenizer_dir,
record_creator=record_creator)
record_creator=record_creator,
post_processor_hook=post_processor_hook)
worker.start()
Loading
Loading