From 896bec677387cd29d449ab2335789e10f173ef1c Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Jun 2026 13:56:03 -0700 Subject: [PATCH 01/16] [TRTLLM-12622][feat] Add native post-processing hook to trtllm-serve Add a user-pluggable, per-request, stateful post-processing hook for trtllm-serve, equivalent to a Triton python-backend post-processor. The hook runs after detokenization and before the per-endpoint response formatter, and may rewrite, suppress, or terminate the streamed output. - New executor/postprocessor_hook.py: PostProcChunk / PostProcVerdict, emit/suppress/terminate, the PostProcessorHook protocol, an import-path loader with a process build-once cache, and apply_post_processor_hook. - Single chokepoint in DetokenizedGenerationResultBase._handle_response, covering both the postproc-worker path and the in-proxy RequestOutput path. The hook path is configured per process (set in BaseLLM.__init__ and postproc_worker_main) via a new PostprocWorkerConfig.post_processor_hook. - Config surface: BaseLlmArgs.post_processor (prototype) plus a trtllm-serve --post_processor flag (also settable via --extra_llm_api_options). - terminate cancels the engine via result.abort() and forces the worker record done; suppress/terminate withhold the raw token-id/logprob channel too; hook exceptions are isolated per request (fail-open + logged). - Reject --post_processor with harmony/gpt-oss models at startup (the harmony path rebuilds output from raw tokens and would bypass the hook). - Unit tests for verdict semantics, per-request state isolation, the loader, and the process global; api_stability reference updated. The hook is text-based and operates post-detok; a hook that rewrites or suppresses text may desync stateful reasoning/tool parsers, and rewriting text does not rewrite the underlying token ids. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tensorrt_llm/commands/serve.py | 26 +- tensorrt_llm/executor/postproc_worker.py | 20 +- tensorrt_llm/executor/postprocessor_hook.py | 253 ++++++++++++++++++ tensorrt_llm/executor/result.py | 19 ++ tensorrt_llm/executor/worker.py | 1 + tensorrt_llm/llmapi/llm.py | 16 ++ tensorrt_llm/llmapi/llm_args.py | 12 + tensorrt_llm/serve/openai_server.py | 13 + .../api_stability/references/llm.yaml | 4 + .../executor/test_postprocessor_hook.py | 247 +++++++++++++++++ 10 files changed, 605 insertions(+), 6 deletions(-) create mode 100644 tensorrt_llm/executor/postprocessor_hook.py create mode 100644 tests/unittest/executor/test_postprocessor_hook.py diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 629dfa2d6a5a..cfa7ff3457c7 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -131,7 +131,7 @@ def is_non_default_or_required(param_name, value, backend, explicit_cli_keys): 3. Different from its default value in the backend's LlmArgs class """ always_include = { - "model", "backend", "tokenizer", "custom_tokenizer", + "model", "backend", "tokenizer", "custom_tokenizer", "post_processor", "postprocess_tokenizer_dir" } @@ -181,6 +181,7 @@ def get_llm_args( model: str, tokenizer: Optional[str] = None, custom_tokenizer: Optional[str] = None, + post_processor: Optional[str] = None, backend: str = "pytorch", max_beam_width: int = BuildConfig.model_fields["max_beam_width"]. 
default, @@ -238,6 +239,8 @@ def get_llm_args( tokenizer, "custom_tokenizer": custom_tokenizer, + "post_processor": + post_processor, "postprocess_tokenizer_dir": tokenizer or model, "kv_cache_config": @@ -617,6 +620,17 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path " "(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').", "prototype")) +@click.option( + "--post_processor", + type=str, + default=None, + help=help_info_with_stability_tag( + "Python import path of a user post-processing hook applied after " + "detokenization and before the response formatter (e.g. " + "'my_pkg.guardrail.MyPostProcessor'). The class must be importable, " + "picklable, take no constructor arguments, and be callable per chunk; " + "it may rewrite, suppress, or terminate the output and owns its own " + "per-request state.", "prototype")) @click.option("--host", type=str, default="localhost", @@ -897,10 +911,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "Types of agents to schedule. 
Now Only Support Open Deep Research agent.") def serve( model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], - host: str, port: int, log_level: str, backend: str, max_beam_width: int, - max_batch_size: int, max_num_tokens: int, max_seq_len: int, - tensor_parallel_size: int, pipeline_parallel_size: int, - context_parallel_size: int, moe_expert_parallel_size: Optional[int], + post_processor: Optional[str], host: str, port: int, log_level: str, + backend: str, max_beam_width: int, max_batch_size: int, + max_num_tokens: int, max_seq_len: int, tensor_parallel_size: int, + pipeline_parallel_size: int, context_parallel_size: int, + moe_expert_parallel_size: Optional[int], moe_cluster_parallel_size: Optional[int], gpus_per_node: Optional[int], free_gpu_memory_fraction: float, kv_cache_dtype: str, num_postprocess_workers: int, trust_remote_code: bool, @@ -981,6 +996,7 @@ def _serve_llm(): model=model, tokenizer=tokenizer, custom_tokenizer=custom_tokenizer, + post_processor=post_processor, backend=backend, max_beam_width=max_beam_width, max_batch_size=max_batch_size, diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 4b7200d2238d..069f445b0725 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -46,6 +46,11 @@ class PostprocWorkerConfig: ''' The config for the postprocess worker. ''' num_postprocess_workers: int = 0 postprocess_tokenizer_dir: Optional[str] = None + # Dotted import path of the user post-processing hook (TRTLLM-12622), or + # None. Propagated into each postproc worker process so the detok chokepoint + # can apply it. NOTE: distinct from ``PostprocParams.post_processor`` above, + # which is the per-endpoint response *formatter* (a Callable), not this hook. 
+ post_processor_hook: Optional[str] = None @property def enabled(self) -> bool: @@ -202,6 +207,12 @@ async def handle_single_input(inp: PostprocWorker.Input, res, metrics, perf_metrics, disaggregated_params = await self._handle_input( inp) record = self._records.get(client_id) + # A post-processing hook that returned `terminate` forces the + # record done (TRTLLM-12622); honor it so the stream stops + # promptly and the record is popped, instead of waiting for the + # engine's own is_final. + if record is not None and record._done: + is_final = True should_abort = record._aborted if record else False finish_reason = record.outputs[0].finish_reason if ( record and record.outputs @@ -268,7 +279,14 @@ async def main(): @print_traceback_on_error def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]], feedout_ipc_addr: tuple[str, Optional[bytes]], - tokenizer_dir: str, record_creator: Callable): + tokenizer_dir: str, + record_creator: Callable, + post_processor_hook: Optional[str] = None): + # Record the configured post-processing hook (TRTLLM-12622) for this worker + # process so the detok chokepoint in DetokenizedGenerationResultBase can + # apply it. + from .postprocessor_hook import set_configured_post_processor_hook + set_configured_post_processor_hook(post_processor_hook) worker = PostprocWorker(feedin_ipc_addr, feedout_ipc_addr, tokenizer_dir=tokenizer_dir, diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py new file mode 100644 index 000000000000..2ecf8927bce1 --- /dev/null +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -0,0 +1,253 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""User-pluggable post-processing hook for ``trtllm-serve`` (TRTLLM-12622). + +This provides a native, per-request, stateful post-processing seam equivalent +to a Triton python-backend post-processor. 
A user supplies a picklable, +importable callable class via the ``--post_processor`` import path; trtllm +builds one instance per process and invokes it once per output, per streaming +chunk (plus a final call), *after* detokenization and *before* the per-endpoint +response formatter. + +The hook owns its own per-request state (keyed by ``chunk.request_id``) exactly +like Triton's model-managed ``self.sequences = {}`` pattern; trtllm passes only +the request id, the per-chunk payload, lifecycle flags, and the cancel signal. + +This module is intentionally dependency-light (stdlib only) so it can be loaded +in the post-processing worker process and reasoned about in isolation. +""" + +import dataclasses +import importlib +import logging +from typing import List, Optional, Protocol, runtime_checkable + +__all__ = [ + "PostProcChunk", + "PostProcVerdict", + "PostProcessorHook", + "emit", + "suppress", + "terminate", + "apply_post_processor_hook", + "load_post_processor_hook", + "get_post_processor_hook", + "set_configured_post_processor_hook", + "get_configured_post_processor_hook", +] + +logger = logging.getLogger(__name__) + +# Process-level cache so the hook instance is built once per process (mirrors +# the "build once in the worker" precedent of ``record_creator``). Keyed by +# import path; the per-request state lives inside the instance. +_HOOK_INSTANCE_CACHE: dict = {} + +# The post-processing hook is global server config (one pipeline for all +# requests), and detokenization runs in different processes depending on +# ``--num_postprocess_workers`` (the postproc worker process when enabled, the +# proxy/serving process otherwise). Each such process records the configured +# import path here at startup; the detok chokepoint reads it. 
+_CONFIGURED_HOOK_PATH: Optional[str] = None + + +def set_configured_post_processor_hook(import_path: Optional[str]) -> None: + """Record the configured hook import path for this process (or ``None``).""" + global _CONFIGURED_HOOK_PATH + _CONFIGURED_HOOK_PATH = import_path + + +def get_configured_post_processor_hook() -> Optional[str]: + """Return the hook import path configured for this process, if any.""" + return _CONFIGURED_HOOK_PATH + + +def load_post_processor_hook(import_path: str) -> "PostProcessorHook": + """Build a post-processor hook instance from a dotted import path. + + Mirrors ``tensorrt_llm.tokenizer.load_custom_tokenizer``: resolve + ``module.path.ClassName``, import the module, fetch the class, instantiate + it with no arguments. The class must be importable and picklable so it can + cross the post-processing worker process boundary. + + Args: + import_path: Dotted path to the hook class, e.g. + ``'my_pkg.guardrail.MyPostProcessor'``. + + Returns: + An instance of the hook class. + + Raises: + ValueError: If the path cannot be resolved, imported, or instantiated. + """ + try: + module_path, class_name = import_path.rsplit(".", 1) + module = importlib.import_module(module_path) + hook_class = getattr(module, class_name) + return hook_class() + except (ValueError, ImportError, AttributeError, TypeError) as e: + raise ValueError( + f"Failed to load post-processor hook '{import_path}': {e}. " + "Expected format: 'module.path.ClassName' resolving to a " + "no-arg-constructible callable class." + ) from e + + +def get_post_processor_hook(import_path: str) -> "PostProcessorHook": + """Return the process-cached hook instance for ``import_path``. + + Builds it on first use and reuses it thereafter so per-request state held by + the instance persists across chunks within this process. 
+ """ + hook = _HOOK_INSTANCE_CACHE.get(import_path) + if hook is None: + hook = load_post_processor_hook(import_path) + _HOOK_INSTANCE_CACHE[import_path] = hook + return hook + + +@dataclasses.dataclass +class PostProcChunk: + """The payload handed to the post-processing hook for one output chunk. + + Attributes: + request_id: Stable identifier for the request; the same value is passed + for every chunk of a given response, so the hook can key its own + per-request state on it. + output_index: Index of the output/beam within the request. + text_diff: Newly detokenized text produced by this chunk (streaming). + For non-streaming requests this equals ``text``. + text: Full accumulated detokenized text so far for this output. + token_ids_diff: Newly generated token ids for this chunk. + is_final: True on the terminating call for this output. + aborted: True if the request has been marked aborted in this process + (e.g. a prior ``terminate`` verdict, or an abort observed by the + detok process). Output-side observation only; do not rely on it to + detect every upstream client cancellation. + streaming: True for streaming requests. + """ + + request_id: int + output_index: int + text_diff: str + text: str + token_ids_diff: List[int] + is_final: bool + aborted: bool + streaming: bool + + +@dataclasses.dataclass +class PostProcVerdict: + """The hook's decision for one chunk. + + Use the :func:`emit`, :func:`suppress`, and :func:`terminate` helpers rather + than constructing this directly. 
+ """ + + action: str # "emit" | "suppress" | "terminate" + text: str = "" + reason: Optional[str] = None + + +def emit(text: str) -> PostProcVerdict: + """Emit ``text`` for this chunk (use to rewrite/redact, or pass through).""" + return PostProcVerdict(action="emit", text=text) + + +def suppress() -> PostProcVerdict: + """Withhold this chunk entirely (no client-visible output).""" + return PostProcVerdict(action="suppress") + + +def terminate(reason: str) -> PostProcVerdict: + """Stop the stream for this request. ``reason`` is surfaced as stop_reason.""" + return PostProcVerdict(action="terminate", reason=reason) + + +@runtime_checkable +class PostProcessorHook(Protocol): + """The interface a user post-processor implements. + + The instance is built once per process (its ``__init__`` is the one-time + setup) and called once per output, per chunk. It owns any per-request state + and is responsible for releasing it on ``chunk.is_final``. + + NOTE: rewriting/suppressing text does not rewrite the underlying token ids; + callers that read both text and token ids should expect them to diverge. + """ + + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: ... + + +def _withhold_token_channel(output) -> None: + """Withhold the raw token-id / logprob channel for this chunk too. + + Without this, a suppressed/terminated chunk would still leak + ``token_ids_diff`` (e.g. on ``/v1/completions`` with ``detokenize=False``) + and ``logprobs_diff`` even though the detokenized text was blanked. + """ + output._last_token_ids_len = len(output.token_ids) + if getattr(output, "logprobs", None) is not None: + output._last_logprobs_len = len(output.logprobs) + + +def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) -> None: + """Run ``hook`` over ``result.outputs`` in place at the detok chokepoint. 
+ + Applies each verdict by rewriting the chunk's text diff on the output + (preserving the already-emitted prefix), suppressing it, or terminating the + stream via the existing abort machinery. + + Hook exceptions are isolated per request: they are logged and the chunk is + passed through unchanged (fail-open), so a buggy hook cannot wedge the + worker or crash the serving loop. This is consistent across the in-proxy and + postproc-worker paths. + """ + is_final = result._done + for output in result.outputs: + chunk = PostProcChunk( + request_id=result.id, + output_index=output.index, + text_diff=output.text_diff, + text=output.text, + token_ids_diff=list(output.token_ids_diff), + is_final=is_final, + aborted=result._aborted, + streaming=streaming, + ) + try: + verdict = hook(chunk) + except Exception: + logger.exception( + "Post-processor hook raised for request %s; passing the chunk through unchanged.", + result.id, + ) + continue + prefix = output.text[: output._last_text_len] + if verdict.action == "emit": + output.text = prefix + verdict.text + elif verdict.action == "suppress": + output.text = prefix + _withhold_token_channel(output) + elif verdict.action == "terminate": + output.text = prefix + verdict.text + _withhold_token_channel(output) + output.finish_reason = "stop" + output.stop_reason = verdict.reason + result._aborted = True + result._done = True + # Cancel the engine request as well. On the in-proxy path this stops + # wasted generation; on the worker path the record's abort() only + # sets the flag and the engine is cancelled by the proxy via + # should_abort. Guarded for results that expose no abort(). 
+ abort = getattr(result, "abort", None) + if callable(abort): + try: + abort() + except Exception: + logger.exception( + "Failed to abort request %s after terminate verdict.", result.id + ) + else: + raise ValueError(f"Unknown post-processor verdict action: {verdict.action!r}") diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 4c2bd2b4cfa8..1853c32037aa 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -877,6 +877,25 @@ def _handle_response(self, response: "GenerationExecutor.Response"): self._done = True break + self._apply_post_processor_hook() + + def _apply_post_processor_hook(self): + """Run the user post-processing hook (TRTLLM-12622) at the detok chokepoint. + + Runs after detok populated ``text``/``text_diff`` and before any + per-endpoint formatter reads them. Shared by the postproc-worker path + and the in-proxy path; the hook is configured per-process and read from + the process-global set at startup. + """ + from .postprocessor_hook import (apply_post_processor_hook, + get_configured_post_processor_hook, + get_post_processor_hook) + import_path = get_configured_post_processor_hook() + if not import_path: + return + hook = get_post_processor_hook(import_path) + apply_post_processor_hook(hook, self, streaming=self._streaming) + # alias PostprocWorker = DetokenizedGenerationResultBase.PostprocWorker diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 62960200ca26..e39533dad5ed 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -281,6 +281,7 @@ def notify_proxy_threads_to_quit(): proxy_result_queue, postproc_worker_config.postprocess_tokenizer_dir, PostprocWorker.default_record_creator, + postproc_worker_config.post_processor_hook, ) postprocess_worker_futures.append(fut) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index f99aa1593c9f..f42064c70056 100644 --- a/tensorrt_llm/llmapi/llm.py +++ 
b/tensorrt_llm/llmapi/llm.py @@ -230,6 +230,20 @@ def __init__(self, "yellow") self.mpi_session = self.args.mpi_session + # Record the configured post-processing hook (TRTLLM-12622) for this + # (LLM/proxy) process. When postproc workers are disabled, the detok + # chokepoint runs here on RequestOutput; when enabled, each worker + # process records it separately via postproc_worker_main. + from ..executor.postprocessor_hook import ( + get_post_processor_hook, set_configured_post_processor_hook) + _post_processor_hook = getattr(self.args, "post_processor", None) + set_configured_post_processor_hook(_post_processor_hook) + if _post_processor_hook: + # Resolve eagerly so a bad import path fails fast at startup (and + # primes the per-process build-once cache) rather than erroring on + # the first request. + get_post_processor_hook(_post_processor_hook) + if self.args.parallel_config.is_multi_gpu: if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: @@ -1417,6 +1431,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, + post_processor_hook=self.args.post_processor, ), is_llm_executor=True) @@ -1565,6 +1580,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, + post_processor_hook=self.args.post_processor, ), is_llm_executor=True, hf_model_dir=self._hf_model_dir, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 7219d59cc76c..7a25dd1053a8 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3118,6 +3118,18 @@ class BaseLlmArgs(StrictBaseModel): "The tokenizer class must implement 'from_pretrained(path, **kwargs)' and the TokenizerBase interface.", 
status="prototype") + post_processor: Optional[str] = Field( + default=None, + description= + "Python import path of a user post-processing hook applied after " + "detokenization and before the per-endpoint response formatter (e.g. " + "'my_pkg.guardrail.MyPostProcessor'). The class must be importable and " + "picklable, take no constructor arguments, and be callable as " + "'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). " + "It runs once per output, per streaming chunk, and may rewrite, " + "suppress, or terminate the output; it owns its own per-request state.", + status="prototype") + skip_tokenizer_init: bool = Field( default=False, description="Whether to skip the tokenizer initialization.") diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 9ae68d9e63e2..5dcf4b2995f8 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -403,6 +403,19 @@ def _init_llm(self, chat_template: Optional[str] = None): else: self.use_harmony = (type(self.model_config).model_type == "gpt_oss") + # The harmony (gpt-oss) path rebuilds the client-visible output from raw + # output token ids rather than the detokenized text, so the + # post-processing hook (TRTLLM-12622), which operates on detok text, + # cannot act there. Fail fast rather than silently bypassing a guardrail. + if self.use_harmony and getattr(self.generator.args, "post_processor", + None): + raise ValueError( + "--post_processor is not supported with harmony/gpt-oss models " + "in this version: the harmony output path is reconstructed from " + "raw token ids and would bypass the text-based hook. 
Disable the " + "hook or set DISABLE_HARMONY_ADAPTER=1 if the harmony path is " + "not needed.") + self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: # NOTE: Use the instance-level ``model_type`` (JSON-derived) here, not diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index a0cccbfa287e..52081417967a 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -64,6 +64,10 @@ methods: annotation: Optional[str] default: null status: prototype + post_processor: + annotation: Optional[str] + default: null + status: prototype # reasoning reasoning_parser: annotation: Optional[str] diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py new file mode 100644 index 000000000000..9b55cbd29ffa --- /dev/null +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -0,0 +1,247 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the trtllm-serve post-processing hook (TRTLLM-12622).""" + +import pytest + +from tensorrt_llm.executor import postprocessor_hook as _pph +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcChunk, + apply_post_processor_hook, + emit, + get_post_processor_hook, + load_post_processor_hook, + suppress, + terminate, +) +from tensorrt_llm.executor.result import CompletionOutput + + +@pytest.fixture(autouse=True) +def _reset_hook_process_globals(): + """Isolate the module-level process globals between tests.""" + saved_path = _pph.get_configured_post_processor_hook() + saved_cache = dict(_pph._HOOK_INSTANCE_CACHE) + _pph.set_configured_post_processor_hook(None) + _pph._HOOK_INSTANCE_CACHE.clear() + try: + yield + finally: + _pph.set_configured_post_processor_hook(saved_path) + _pph._HOOK_INSTANCE_CACHE.clear() + _pph._HOOK_INSTANCE_CACHE.update(saved_cache) + + +class _FakeResult: + """Minimal stand-in for a GenerationResult at the detok chokepoint.""" + + def __init__(self, outputs, req_id=1, done=False, aborted=False, has_abort=False): + self.outputs = outputs + self.id = req_id + self._done = done + self._aborted = aborted + self.abort_called = 0 + if has_abort: + self.abort = self._abort + + def _abort(self): + self.abort_called += 1 + self._aborted = True + + +def _make_output(text, last_text_len=0, index=0, token_ids=None, logprobs=None): + out = CompletionOutput(index=index, text=text) + out.token_ids = token_ids if token_ids is not None else [1, 2, 3] + out.logprobs = logprobs if logprobs is not None else [-0.1, -0.2, -0.3] + out._last_text_len = last_text_len + return out + + +def test_rewrite_streaming_diff_replaces_only_the_diff(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + def hook(chunk: PostProcChunk): + assert chunk.text_diff == " world" + assert chunk.text == "hello world" + return emit(chunk.text_diff.upper()) + + 
apply_post_processor_hook(hook, result, streaming=True) + + # The already-emitted prefix is preserved; only the diff is rewritten. + assert out.text == "hello WORLD" + assert out.text_diff == " WORLD" + + +def test_suppress_withholds_the_diff_keeping_prefix(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=True) + + assert out.text == "hello" + assert out.text_diff == "" + + +def test_terminate_marks_aborted_and_done_with_reason(): + out = _make_output("safe bad", last_text_len=len("safe ")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: terminate("policy_violation"), result, streaming=True) + + assert result._aborted is True + assert result._done is True + assert out.finish_reason == "stop" + assert out.stop_reason == "policy_violation" + # The violating diff is withheld. + assert out.text == "safe " + + +def test_suppress_blanks_token_and_logprob_diffs(): + # mid-stream chunk with one new token/logprob beyond what was emitted. + out = _make_output("hello world", last_text_len=len("hello")) + out._last_token_ids_len = 2 + out._last_logprobs_len = 2 + assert out.token_ids_diff == [3] + assert out.logprobs_diff == [-0.3] + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=True) + + assert out.text == "hello" + # The raw token/logprob channel is withheld too, not just the text. 
+ assert out.token_ids_diff == [] + assert out.logprobs_diff == [] + + +def test_terminate_calls_abort_when_available_and_blanks_token_channel(): + out = _make_output("safe bad", last_text_len=len("safe ")) + out._last_token_ids_len = 1 + result = _FakeResult([out], has_abort=True) + + apply_post_processor_hook(lambda c: terminate("policy"), result, streaming=True) + + assert result._aborted is True + assert result._done is True + assert result.abort_called == 1 + assert out.token_ids_diff == [] + assert out.finish_reason == "stop" + assert out.stop_reason == "policy" + + +def test_terminate_without_abort_attr_does_not_crash(): + out = _make_output("safe bad", last_text_len=len("safe ")) + result = _FakeResult([out], has_abort=False) + apply_post_processor_hook(lambda c: terminate("policy"), result, streaming=True) + assert result._aborted is True + assert result._done is True + + +def test_hook_exception_fails_open_passthrough(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + def boom(chunk): + raise RuntimeError("hook bug") + + # Must not propagate; the chunk passes through unchanged (fail-open). + apply_post_processor_hook(boom, result, streaming=True) + + assert out.text == "hello world" + assert out.text_diff == " world" + + +def test_non_streaming_rewrites_full_text(): + # Non-stream single response: _last_text_len == 0, so diff == full text. 
+ out = _make_output("the full answer", last_text_len=0) + result = _FakeResult([out], done=True) + + def hook(chunk: PostProcChunk): + assert chunk.text_diff == chunk.text == "the full answer" + return emit("REDACTED") + + apply_post_processor_hook(hook, result, streaming=False) + + assert out.text == "REDACTED" + + +def test_passthrough_emit_is_idempotent(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: emit(c.text_diff), result, streaming=True) + + assert out.text == "hello world" + assert out.text_diff == " world" + + +def test_per_request_state_is_keyed_by_request_id(): + """The hook owns its per-request state; trtllm only passes request_id.""" + + class Counter: + def __init__(self): + self.state = {} + + def __call__(self, chunk: PostProcChunk): + n = self.state.get(chunk.request_id, 0) + 1 + self.state[chunk.request_id] = n + if chunk.is_final: + self.state.pop(chunk.request_id, None) + return emit(f"{chunk.text_diff}#{n}") + + hook = Counter() + r1a = _FakeResult([_make_output("a", 0)], req_id=1) + r2 = _FakeResult([_make_output("b", 0)], req_id=2) + r1b = _FakeResult([_make_output("ac", 1)], req_id=1, done=True) + + apply_post_processor_hook(hook, r1a, streaming=True) + apply_post_processor_hook(hook, r2, streaming=True) + apply_post_processor_hook(hook, r1b, streaming=True) + + # Request 1 counts 1 then 2, independently of request 2 (which counts 1). + assert r1a.outputs[0].text == "a#1" + assert r2.outputs[0].text == "b#1" + assert r1b.outputs[0].text == "ac#2" + # State released on the final chunk. 
+ assert 1 not in hook.state + + +def test_unknown_verdict_action_raises(): + from tensorrt_llm.executor.postprocessor_hook import PostProcVerdict + + out = _make_output("x", 0) + result = _FakeResult([out]) + with pytest.raises(ValueError, match="Unknown post-processor verdict"): + apply_post_processor_hook(lambda c: PostProcVerdict(action="bogus"), result, streaming=True) + + +def test_loader_resolves_import_path(): + # Any importable, no-arg-constructible class works as a smoke test. + hook = load_post_processor_hook("collections.OrderedDict") + assert hook is not None + + +def test_loader_raises_on_bad_path(): + with pytest.raises(ValueError, match="Failed to load post-processor hook"): + load_post_processor_hook("no.such.module.Nope") + + +def test_get_hook_builds_once_per_process(): + a = get_post_processor_hook("collections.OrderedDict") + b = get_post_processor_hook("collections.OrderedDict") + assert a is b + + +def test_configured_hook_global_roundtrip(): + from tensorrt_llm.executor.postprocessor_hook import ( + get_configured_post_processor_hook, + set_configured_post_processor_hook, + ) + + try: + assert get_configured_post_processor_hook() is None + set_configured_post_processor_hook("my_pkg.guardrail.G") + assert get_configured_post_processor_hook() == "my_pkg.guardrail.G" + finally: + set_configured_post_processor_hook(None) + assert get_configured_post_processor_hook() is None From 5a7dbbaf0244b0b0b88ce401940b408e6f356e7f Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Jun 2026 14:03:06 -0700 Subject: [PATCH 02/16] [TRTLLM-12622][test] Add e2e tests for trtllm-serve post-processing hook Launch a real trtllm-serve with --post_processor and assert the client-visible effect (rewrite / suppress / terminate) across the chat and completions endpoints, streaming and non-streaming, with the postproc worker pool both disabled (in-proxy detok) and enabled (worker-process detok). 
- _postproc_hook_samples.py: stateless deterministic sample hooks (UppercaseHook / SuppressHook / TerminateHook). - _test_openai_post_processor.py: the endpoint matrix on TinyLlama-1.1B. - test_e2e.py::test_openai_post_processor wrapper + l0_a10 test-list entry. Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tests/integration/defs/test_e2e.py | 7 + .../integration/test_lists/test-db/l0_a10.yml | 1 + .../llmapi/apps/_postproc_hook_samples.py | 40 +++++ .../apps/_test_openai_post_processor.py | 152 ++++++++++++++++++ 4 files changed, 200 insertions(+) create mode 100644 tests/unittest/llmapi/apps/_postproc_hook_samples.py create mode 100644 tests/unittest/llmapi/apps/_test_openai_post_processor.py diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 64a40cebc21c..3ac865413622 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -902,6 +902,13 @@ def test_openai_tool_call(llm_root, llm_venv): str(test_root / "_test_openai_tool_call.py")]) +def test_openai_post_processor(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_post_processor.py")]) + + @pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"]) def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str): test_root = unittest_path() / "llmapi" / "apps" diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 440fe2885fab..5124e2f11fbf 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -105,6 +105,7 @@ l0_a10: - test_e2e.py::test_openai_misc_example[pytorch] - test_e2e.py::test_openai_reasoning[pytorch] - test_e2e.py::test_openai_tool_call + - test_e2e.py::test_openai_post_processor - test_e2e.py::test_openai_responses_entrypoint - 
test_e2e.py::test_openai_completions_example[pytorch] - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) diff --git a/tests/unittest/llmapi/apps/_postproc_hook_samples.py b/tests/unittest/llmapi/apps/_postproc_hook_samples.py new file mode 100644 index 000000000000..24f978a4d7e8 --- /dev/null +++ b/tests/unittest/llmapi/apps/_postproc_hook_samples.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Sample post-processing hooks for the trtllm-serve hook integration test +(TRTLLM-12622). + +These are deliberately stateless and deterministic so the test can assert the +client-visible effect regardless of the (non-deterministic) model output. Each +class is a top-level, no-arg-constructible, importable callable so it can be +supplied to ``trtllm-serve --post_processor`` and reconstructed by reference in +the post-processing worker process. +""" + +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcChunk, + PostProcVerdict, + emit, + suppress, + terminate, +) + + +class UppercaseHook: + """Rewrite every chunk's text to upper case.""" + + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return emit(chunk.text_diff.upper()) + + +class SuppressHook: + """Withhold all output (every chunk is suppressed).""" + + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return suppress() + + +class TerminateHook: + """Terminate the stream immediately on the first chunk seen.""" + + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return terminate("test_policy") diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py new file mode 100644 index 000000000000..09d2474f7eac --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -0,0 +1,152 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end tests for the trtllm-serve post-processing hook (TRTLLM-12622). + +Launches a real ``trtllm-serve`` with ``--post_processor`` pointing at one of +the sample hooks in ``_postproc_hook_samples`` and asserts the client-visible +effect (rewrite / suppress / terminate) across the chat and completions +endpoints, streaming and non-streaming, with the postproc worker pool both +disabled (in-proxy detok) and enabled (worker-process detok). + +The hooks are stateless and deterministic so assertions hold regardless of the +model's (non-deterministic) output: + - UppercaseHook: every chunk is upper-cased -> output == output.upper() + - SuppressHook: every chunk is withheld -> output == "" + - TerminateHook: first chunk terminates -> output == "", stops early +""" + +import os + +import openai +import pytest + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +pytestmark = pytest.mark.threadleak(enabled=False) + +# Dotted import paths into the sample-hook module shipped alongside this test. 
+_HOOKS = { + "uppercase": "_postproc_hook_samples.UppercaseHook", + "suppress": "_postproc_hook_samples.SuppressHook", + "terminate": "_postproc_hook_samples.TerminateHook", +} + + +@pytest.fixture(scope="module") +def model_name(): + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module", params=[0, 2], ids=["disable_processpool", "enable_processpool"]) +def num_postprocess_workers(request): + return request.param + + +@pytest.fixture(scope="module", params=list(_HOOKS), ids=list(_HOOKS)) +def hook(request): + return request.param + + +@pytest.fixture(scope="module") +def server(model_name: str, num_postprocess_workers: int, hook: str): + model_path = get_model_path(model_name) + args = [ + "--backend", + "pytorch", + # co-exist with other servers + "--kv_cache_free_gpu_memory_fraction", + "0.2", + "--num_postprocess_workers", + f"{num_postprocess_workers}", + "--post_processor", + _HOOKS[hook], + ] + # Make the sample-hook module importable by the server (and its postproc + # worker) subprocesses. 
+ apps_dir = os.path.dirname(os.path.abspath(__file__)) + env = os.environ.copy() + env["PYTHONPATH"] = apps_dir + os.pathsep + env.get("PYTHONPATH", "") + with RemoteOpenAIServer(model_path, args, env=env) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_client() + + +@pytest.fixture(scope="module") +def async_client(server: RemoteOpenAIServer): + return server.get_async_client() + + +def _assert_text_matches_hook(hook: str, text: str): + if hook == "uppercase": + assert text == text.upper(), f"expected upper-cased text, got {text!r}" + elif hook == "suppress": + assert text == "", f"expected suppressed (empty) text, got {text!r}" + elif hook == "terminate": + assert text == "", f"expected terminated (empty) text, got {text!r}" + else: + raise AssertionError(f"unknown hook {hook}") + + +def test_completions_non_streaming(client: openai.OpenAI, model_name: str, hook: str): + completion = client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + ) + text = completion.choices[0].text + _assert_text_matches_hook(hook, text) + if hook == "terminate": + assert completion.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_completions_streaming(async_client: openai.AsyncOpenAI, model_name: str, hook: str): + stream = await async_client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + stream=True, + ) + text = "" + async for chunk in stream: + token = chunk.choices[0].text + if token: + text += token + _assert_text_matches_hook(hook, text) + + +def test_chat_non_streaming(client: openai.OpenAI, model_name: str, hook: str): + chat = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Hello, tell me a short story."}], + max_tokens=16, + temperature=0.0, + ) + content = 
chat.choices[0].message.content or "" + _assert_text_matches_hook(hook, content) + if hook == "terminate": + assert chat.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_chat_streaming(async_client: openai.AsyncOpenAI, model_name: str, hook: str): + stream = await async_client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Hello, tell me a short story."}], + max_tokens=16, + temperature=0.0, + stream=True, + ) + content = "" + async for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + content += delta + _assert_text_matches_hook(hook, content) From 11c5f23d942feb993b0fc9831bdbf083122ab14f Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:42:55 -0700 Subject: [PATCH 03/16] [TRTLLM-12622][fix] Drop late proxy responses after post-processor hook terminate When a post-processing hook returns `terminate`, the result is marked done and popped from the proxy's `_results` map, but the engine can still emit in-flight responses for the same client_id (abort is async, and the postproc worker recreates a record for any late response). Those late responses reached `process_res` after the result was removed, raising KeyError and recording a fatal engine error that tore down the whole engine. Look the result up with `.get()` and drop late responses for already-finalized client_ids; make the corresponding pop idempotent. Verified on GB200: the previously-failing test_chat_streaming[terminate-enable_processpool] now passes (post-processor hook e2e suite 24/24). 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tensorrt_llm/executor/proxy.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index a0cb59111c9d..b9298640cb8f 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -327,7 +327,16 @@ def process_res(res): nonlocal event_loop nonlocal async_queues - queue = self._results[client_id].queue + # The result may have already been finalized and removed — e.g. a + # post-processor hook (TRTLLM-12622) returned `terminate`, which + # marks the result done and pops it here, while the engine still has + # in-flight responses for the same client_id (abort is async, and the + # postproc worker recreates a record for any late response). Drop + # such late responses instead of crashing with a KeyError. + result = self._results.get(client_id) + if result is None: + return + queue = result.queue if isinstance(queue, _SyncQueue): queue.put_nowait(res) async_queues.append(queue) @@ -342,7 +351,7 @@ def process_res(res): res, ErrorResponse) or (isinstance(res, PostprocWorker.Output) and res.is_final): - self._results.pop(client_id) + self._results.pop(client_id, None) res = res if isinstance(res, list) else [res] From 1146de68277b41ebb0168c48ec0ede2c0cdf8460 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Jun 2026 16:54:48 -0700 Subject: [PATCH 04/16] [TRTLLM-12622][doc] Document the trtllm-serve post-processing hook Add a feature doc covering the native post-processing hook: how to enable it via --post_processor (CLI and YAML), the PostProcChunk / verdict interface, three worked examples (rewrite, stateful guardrail with terminate, suppress), per-request state guidance, supported endpoints and limitations, and pointers to the unit and e2e tests. Register it in the features toctree. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 181 ++++++++++++++++++++ docs/source/index.rst | 1 + 2 files changed, 182 insertions(+) create mode 100644 docs/source/features/post-processor-hook.md diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md new file mode 100644 index 000000000000..c0761ff67d8e --- /dev/null +++ b/docs/source/features/post-processor-hook.md @@ -0,0 +1,181 @@ +(post-processor-hook)= + +# Post-Processing Hook + +`trtllm-serve` supports a user-supplied **post-processing hook**: a native, per-request seam that runs +on each generated output *after* detokenization and *before* the per-endpoint response formatter. It +lets a deployment rewrite, redact, suppress, or terminate model output — including stateful logic that +spans the chunks of a streamed response — without modifying TensorRT LLM source. + +The hook is a plain Python callable class supplied by import path, mirroring `--custom_tokenizer`. It +is built once per process and invoked once per output, per streaming chunk (plus a final call), so it +can hold its own per-request state. + +```{note} +This feature is a prototype and its interface may change in a future release. +``` + +For the interface definitions referenced below, see +[`tensorrt_llm/executor/postprocessor_hook.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/executor/postprocessor_hook.py). + +## Enabling the hook + +Pass the dotted import path of your hook class to `--post_processor`: + +```bash +trtllm-serve --post_processor my_pkg.guardrail.MyPostProcessor +``` + +Equivalently, set it in a YAML config passed via `--extra_llm_api_options`: + +```yaml +post_processor: my_pkg.guardrail.MyPostProcessor +``` + +The class must be: + +- **Importable** — installed (`pip install`) or on `PYTHONPATH` when the server (and its + post-processing worker processes, if any) start. 
+- **Picklable and no-argument-constructible** — the hook is reconstructed by reference inside each + process; `__init__` takes no required arguments and is the one-time setup point. + +The hook works the same with or without the out-of-process post-processing worker pool +(`--num_postprocess_workers`). + +## The hook interface + +A hook implements a single method, `__call__(self, chunk) -> verdict`: + +```python +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcChunk, + PostProcVerdict, + emit, + suppress, + terminate, +) + + +class MyPostProcessor: + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return emit(chunk.text_diff) # pass through unchanged +``` + +### `PostProcChunk` + +The payload handed to the hook for one output chunk: + +| Field | Description | +|-------|-------------| +| `request_id` | Stable identifier for the request; the same value is passed for every chunk of a response, so the hook can key its per-request state on it. | +| `output_index` | Index of the output/beam within the request. | +| `text_diff` | Newly detokenized text produced by this chunk (streaming). For non-streaming requests this equals `text`. | +| `text` | Full accumulated detokenized text so far for this output. | +| `token_ids_diff` | Newly generated token ids for this chunk. | +| `is_final` | `True` on the terminating call for this output. | +| `aborted` | `True` if the request has been marked aborted in this process. Output-side observation only. | +| `streaming` | `True` for streaming requests. | + +### Verdicts + +Return one of the following helpers from `__call__`: + +| Helper | Effect | +|--------|--------| +| `emit(text)` | Emit `text` for this chunk. Use it to pass output through unchanged (`emit(chunk.text_diff)`) or to rewrite/redact it. | +| `suppress()` | Withhold this chunk entirely (no client-visible output for it). | +| `terminate(reason)` | Stop the stream for this request. 
`reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. | + +## Usage examples + +### Rewrite output + +A stateless hook that upper-cases every chunk: + +```python +from tensorrt_llm.executor.postprocessor_hook import PostProcChunk, PostProcVerdict, emit + + +class UpperCaseHook: + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return emit(chunk.text_diff.upper()) +``` + +### Stateful guardrail that terminates on a banned phrase + +This hook accumulates text per request (keyed by `request_id`), stops the stream as soon as a banned +phrase appears, and releases its state when the request finishes: + +```python +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcChunk, PostProcVerdict, emit, terminate, +) + + +class BannedPhraseGuard: + _BANNED = ("forbidden phrase",) + + def __init__(self): + # Per-request accumulators owned entirely by the hook. + self._buffers: dict[int, str] = {} + + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + buffer = self._buffers.get(chunk.request_id, "") + chunk.text_diff.lower() + self._buffers[chunk.request_id] = buffer + + if any(phrase in buffer for phrase in self._BANNED): + self._buffers.pop(chunk.request_id, None) + return terminate("banned_phrase") + + if chunk.is_final: + self._buffers.pop(chunk.request_id, None) + return emit(chunk.text_diff) +``` + +### Suppress output + +A hook that withholds all client-visible text: + +```python +from tensorrt_llm.executor.postprocessor_hook import PostProcChunk, PostProcVerdict, suppress + + +class SuppressHook: + def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + return suppress() +``` + +## Per-request state + +The hook instance is built once per process and shared across all requests handled by that process, so +any per-request state must be keyed by `chunk.request_id` and released when `chunk.is_final` is seen +(or after a `terminate`). 
State is not shared across processes; when the post-processing worker pool is +enabled, all chunks of a single request are still routed to the same worker, so per-request state +remains consistent for that request. + +## Supported endpoints and limitations + +- **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. The hook also + applies to the `responses` endpoint for non-harmony models. +- **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from + raw token ids, it would bypass the text-based hook, so the server fails fast at startup when + `--post_processor` is combined with a harmony model. +- **Text vs. token ids**: rewriting or suppressing text does not rewrite the underlying `token_ids` or + `logprobs`. Clients that read both should expect them to diverge. +- **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that + rewrites or suppresses text may desync those parsers; prefer `terminate`, or apply such hooks to + plain-text requests. + +## Tests + +The hook's unit and end-to-end tests double as runnable usage examples: + +```bash +# Unit tests for the hook contract (rewrite / suppress / terminate, per-request state, loader) +pytest tests/unittest/executor/test_postprocessor_hook.py -v +``` + +- End-to-end serving tests across endpoints, streaming modes, and worker-pool settings: + `tests/unittest/llmapi/apps/_test_openai_post_processor.py`. +- The deterministic sample hooks used by those tests: + `tests/unittest/llmapi/apps/_postproc_hook_samples.py`. diff --git a/docs/source/index.rst b/docs/source/index.rst index 5b5a163278f3..595a3e5b095b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -75,6 +75,7 @@ Welcome to TensorRT LLM's Documentation! 
features/quantization.md features/sampling.md features/additional-outputs.md + features/post-processor-hook.md features/guided-decoding.md features/speculative-decoding.md features/checkpoint-loading.md From b46c4e74219a403909d8f7450c904fa4f1f4a32c Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Wed, 10 Jun 2026 17:05:17 -0700 Subject: [PATCH 05/16] [TRTLLM-12622][chore] Address post-processor hook review feedback - test: streaming terminate cases now assert finish_reason == "stop", so the hook (not an empty generation) is verified to be what stopped the stream (the cell that previously hit the proxy KeyError race). - refactor: hoist the post-processor-hook import in result.py to module scope, off the per-chunk detok hot path (postprocessor_hook is stdlib-only, so there is no circular import). - docs: clarify that PostProcChunk.is_final is request-level. No behavior change to the serving path. GPU re-validation pending. Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tensorrt_llm/executor/postprocessor_hook.py | 6 ++++++ tensorrt_llm/executor/result.py | 6 +++--- .../llmapi/apps/_test_openai_post_processor.py | 16 ++++++++++++++++ 3 files changed, 25 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 2ecf8927bce1..523c4fd84e8d 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -204,6 +204,12 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) worker or crash the serving loop. This is consistent across the in-proxy and postproc-worker paths. """ + # ``is_final`` is request-level (``result._done``), not per-output: a + # ``terminate`` verdict on one output marks the whole request done. 
Under the + # locked 1:1 single-output scope (TRTLLM-12622 §15.1) this is exact; hooks key + # their per-request state on ``request_id`` and release it on ``is_final``, so + # request-level finality is the correct cleanup signal regardless of output + # count. is_final = result._done for output in result.outputs: chunk = PostProcChunk( diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 1853c32037aa..3e5a9acb24e6 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -30,6 +30,9 @@ from ..metrics.perf_utils import \ process_req_perf_metrics as _process_req_perf_metrics from ..sampling_params import LogprobParams, SamplingParams +from .postprocessor_hook import (apply_post_processor_hook, + get_configured_post_processor_hook, + get_post_processor_hook) from .utils import ErrorResponse, has_event_loop, is_llm_response if TYPE_CHECKING: @@ -887,9 +890,6 @@ def _apply_post_processor_hook(self): and the in-proxy path; the hook is configured per-process and read from the process-global set at startup. 
""" - from .postprocessor_hook import (apply_post_processor_hook, - get_configured_post_processor_hook, - get_post_processor_hook) import_path = get_configured_post_processor_hook() if not import_path: return diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py index 09d2474f7eac..0be202efefff 100644 --- a/tests/unittest/llmapi/apps/_test_openai_post_processor.py +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -115,11 +115,19 @@ async def test_completions_streaming(async_client: openai.AsyncOpenAI, model_nam stream=True, ) text = "" + finish_reason = None async for chunk in stream: token = chunk.choices[0].text if token: text += token + if chunk.choices[0].finish_reason is not None: + finish_reason = chunk.choices[0].finish_reason _assert_text_matches_hook(hook, text) + if hook == "terminate": + # The hook (not an empty generation) must have stopped the stream. + assert finish_reason == "stop", ( + f"expected finish_reason 'stop' from terminate, got {finish_reason!r}" + ) def test_chat_non_streaming(client: openai.OpenAI, model_name: str, hook: str): @@ -145,8 +153,16 @@ async def test_chat_streaming(async_client: openai.AsyncOpenAI, model_name: str, stream=True, ) content = "" + finish_reason = None async for chunk in stream: delta = chunk.choices[0].delta.content if delta: content += delta + if chunk.choices[0].finish_reason is not None: + finish_reason = chunk.choices[0].finish_reason _assert_text_matches_hook(hook, content) + if hook == "terminate": + # The hook (not an empty generation) must have stopped the stream. 
+ assert finish_reason == "stop", ( + f"expected finish_reason 'stop' from terminate, got {finish_reason!r}" + ) From 959f2daf1ff69e4cd1b7c887487b75392f43db71 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 13:17:35 -0700 Subject: [PATCH 06/16] [TRTLLM-12622][chore] Fail-fast hook registration + review fixes Address PR review feedback on the trtllm-serve post-processing hook: - Fail fast on conflicting in-process hook registration: a second LLM in the same process with a different --post_processor now raises instead of silently clobbering the process-global and applying the wrong hook to the already-running instance. Re-registering the same path is a no-op; clearing to None (e.g. on shutdown) is always allowed. - Validate the import path BEFORE recording the process-global in BaseLLM.__init__, so a bad path leaves no stale registration; clear the global in BaseLLM.shutdown() so sequential LLMs still work. - Add a focused unit test for the proxy late-response drop (the .get() at dispatch_result_task): a response for an already-popped client_id after a terminate is dropped without KeyError. - Doc: clarify that engine batching is transparent (one call per request, keyed by request_id). - Drop an internal study-plan section reference from a code comment. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 6 ++ tensorrt_llm/executor/postprocessor_hook.py | 28 +++++++- tensorrt_llm/llmapi/llm.py | 19 ++++- .../executor/test_postprocessor_hook.py | 31 ++++++++ .../executor/test_proxy_postproc_terminate.py | 71 +++++++++++++++++++ 5 files changed, 150 insertions(+), 5 deletions(-) create mode 100644 tests/unittest/executor/test_proxy_postproc_terminate.py diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index c0761ff67d8e..93b9f4e5c06f 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -153,6 +153,12 @@ any per-request state must be keyed by `chunk.request_id` and released when `chu enabled, all chunks of a single request are still routed to the same worker, so per-request state remains consistent for that request. +Engine-level batching is transparent to the hook: even when many requests are batched and run together +in the engine, the hook is still invoked **once per request** (per output, per chunk), with +`chunk.request_id` identifying which request the chunk belongs to. There is no batched-call form — the +hook never receives more than one request's data in a single call, so keying state on `request_id` is +sufficient to keep concurrent requests isolated. + ## Supported endpoints and limitations - **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. 
The hook also diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 523c4fd84e8d..a39cf3e054b9 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -52,8 +52,30 @@ def set_configured_post_processor_hook(import_path: Optional[str]) -> None: - """Record the configured hook import path for this process (or ``None``).""" + """Record the configured hook import path for this process (or ``None``). + + The hook is process-global server configuration (one pipeline for all + requests in this process), so registering a *different* non-``None`` hook + while one is already active is rejected: a second ``LLM`` in the same + process with a conflicting ``--post_processor`` would otherwise silently + start applying the wrong hook to the already-running instance's responses. + Re-registering the same path is a no-op, and clearing to ``None`` (e.g. on + shutdown) is always allowed. In normal serving each server/rank process + hosts a single ``LLM``, so this only guards genuinely mixed-LLM processes. + """ global _CONFIGURED_HOOK_PATH + if ( + import_path is not None + and _CONFIGURED_HOOK_PATH is not None + and import_path != _CONFIGURED_HOOK_PATH + ): + raise RuntimeError( + "A different post-processor hook is already registered in this " + f"process ({_CONFIGURED_HOOK_PATH!r}); refusing to register " + f"{import_path!r}. The post-processing hook is process-global " + "configuration; running multiple LLMs with different " + "--post_processor values in one process is not supported." + ) _CONFIGURED_HOOK_PATH = import_path @@ -206,8 +228,8 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) """ # ``is_final`` is request-level (``result._done``), not per-output: a # ``terminate`` verdict on one output marks the whole request done. 
Under the - # locked 1:1 single-output scope (TRTLLM-12622 §15.1) this is exact; hooks key - # their per-request state on ``request_id`` and release it on ``is_final``, so + # locked 1:1 single-output scope (TRTLLM-12622) this is exact; hooks key their + # per-request state on ``request_id`` and release it on ``is_final``, so # request-level finality is the correct cleanup signal regardless of output # count. is_final = result._done diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index f42064c70056..081cb0ebfc97 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -237,12 +237,18 @@ def __init__(self, from ..executor.postprocessor_hook import ( get_post_processor_hook, set_configured_post_processor_hook) _post_processor_hook = getattr(self.args, "post_processor", None) - set_configured_post_processor_hook(_post_processor_hook) if _post_processor_hook: # Resolve eagerly so a bad import path fails fast at startup (and # primes the per-process build-once cache) rather than erroring on - # the first request. + # the first request. Validate BEFORE recording the process-global so + # a failed import leaves no stale hook registered for this process. get_post_processor_hook(_post_processor_hook) + # Record the configured hook for this process. Raises if a different + # hook is already registered in-process (mixed-LLM process), so a second + # instance can never silently apply the wrong hook to the first's + # responses. + set_configured_post_processor_hook(_post_processor_hook) + self._post_processor_hook = _post_processor_hook if self.args.parallel_config.is_multi_gpu: if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( @@ -1204,6 +1210,15 @@ def shutdown(self) -> None: self.mpi_session.shutdown() self.mpi_session = None + # Release this process's post-processor hook registration so a + # subsequent LLM with a different --post_processor can be created in the + # same process (e.g. 
sequential use in tests or notebooks). + if getattr(self, "_post_processor_hook", None): + from ..executor.postprocessor_hook import \ + set_configured_post_processor_hook + set_configured_post_processor_hook(None) + self._post_processor_hook = None + def _check_health(self) -> bool: """Check if the LLM is healthy. diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 9b55cbd29ffa..5b771aa90e7b 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -245,3 +245,34 @@ def test_configured_hook_global_roundtrip(): finally: set_configured_post_processor_hook(None) assert get_configured_post_processor_hook() is None + + +def test_configured_hook_rejects_conflicting_registration(): + """A conflicting in-process hook must fail fast (TRTLLM-12622). + + A second, different hook in the same process is rejected rather than + silently clobbering the global and applying the wrong hook. + """ + from tensorrt_llm.executor.postprocessor_hook import ( + get_configured_post_processor_hook, + set_configured_post_processor_hook, + ) + + try: + set_configured_post_processor_hook("my_pkg.guardrail.A") + # Re-registering the same path is an idempotent no-op. + set_configured_post_processor_hook("my_pkg.guardrail.A") + assert get_configured_post_processor_hook() == "my_pkg.guardrail.A" + + # A different non-None hook is rejected and the active one is unchanged. + with pytest.raises(RuntimeError, match="already registered"): + set_configured_post_processor_hook("my_pkg.guardrail.B") + assert get_configured_post_processor_hook() == "my_pkg.guardrail.A" + + # Clearing to None is always allowed (e.g. on shutdown), after which a + # new hook can be registered. 
+ set_configured_post_processor_hook(None) + set_configured_post_processor_hook("my_pkg.guardrail.B") + assert get_configured_post_processor_hook() == "my_pkg.guardrail.B" + finally: + set_configured_post_processor_hook(None) diff --git a/tests/unittest/executor/test_proxy_postproc_terminate.py b/tests/unittest/executor/test_proxy_postproc_terminate.py new file mode 100644 index 000000000000..6f03526e5f1f --- /dev/null +++ b/tests/unittest/executor/test_proxy_postproc_terminate.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for proxy late-response handling after a hook terminate. + +When a post-processor hook (TRTLLM-12622) returns ``terminate`` the result is +marked done and popped from the +proxy's ``_results`` map, but the engine may still have in-flight responses for +the same ``client_id`` (abort is async). ``GenerationExecutorProxy.dispatch_result_task`` +must drop those late responses via ``self._results.get(client_id)`` rather than +``pop``-ing a missing key and crashing the dispatch thread with a ``KeyError``. +""" + +from types import SimpleNamespace + +from tensorrt_llm.executor.proxy import GenerationExecutorProxy + + +class _FakeResultQueue: + """Returns one queued item per ``get()`` call (mirrors the IPC queue).""" + + def __init__(self, items): + self._items = list(items) + + def get(self): + return self._items.pop(0) + + +class _RecordingQueue: + """Stand-in for a result's delivery queue (non-``_SyncQueue`` branch).""" + + def __init__(self): + self.delivered = [] + + def put(self, res): + self.delivered.append(res) + + +def _make_proxy(): + # Avoid GenerationExecutorProxy.__init__ (it spawns workers); we only + # exercise the pure dispatch logic with hand-set attributes. 
+ return object.__new__(GenerationExecutorProxy) + + +def test_late_response_after_terminate_is_dropped_without_keyerror(): + """A response for an already-popped client_id must be dropped, not crash.""" + proxy = _make_proxy() + proxy._results = {} # terminate already finalized + popped this client_id + + late = SimpleNamespace(client_id=999, has_error=False, result=SimpleNamespace(is_final=True)) + proxy.result_queue = _FakeResultQueue([late]) + + # Must not raise KeyError; the dispatch loop stays alive (returns True). + assert proxy.dispatch_result_task() is True + + +def test_final_response_is_delivered_and_popped(): + """A live, final response is delivered and removed from ``_results``. + + This establishes the exact condition under which a later duplicate becomes + the 'late response' that the drop above guards against. + """ + proxy = _make_proxy() + queue = _RecordingQueue() + proxy._results = {7: SimpleNamespace(queue=queue)} + + resp = SimpleNamespace(client_id=7, has_error=False, result=SimpleNamespace(is_final=True)) + proxy.result_queue = _FakeResultQueue([resp]) + + assert proxy.dispatch_result_task() is True + assert queue.delivered == [resp] + assert 7 not in proxy._results # popped on is_final From 278da1f6c5e9a9ef4fca4207e9359e3c5fdb20c2 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 14:12:23 -0700 Subject: [PATCH 07/16] [TRTLLM-12622][refactor] Scope post-processing hook to the LLM instance Replace the process-global post-processing hook ownership with per-instance ownership, so the hook is a property of the LLM (and of each post-processing worker process), like the tokenizer it sits next to in BaseLlmArgs. - postprocessor_hook.py: drop the process globals and their accessors (_CONFIGURED_HOOK_PATH, set_/get_configured_post_processor_hook, _HOOK_INSTANCE_CACHE, get_post_processor_hook) and the fail-fast conflict guard. Keep load_post_processor_hook / apply_post_processor_hook. 
- result.py: DetokenizedGenerationResultBase carries its own _post_processor_hook; the detok chokepoint applies that instance attribute instead of reading a process global. - llm.py: BaseLLM builds its own hook instance (eager import validation retained) and threads it onto the RequestOutput via _from_generation_result; remove the global registration and shutdown-clear. - postproc_worker.py: each PostprocWorker builds and owns one hook instance (mirroring the tokenizer) and injects it onto every record it creates. - Tests: replace the global/conflict/cache tests with instance-ownership tests; update the docs to describe per-instance ownership. Independent LLM instances in one process now stay isolated without any fail-fast guard, and the hook instance never crosses an IPC boundary (each owner builds it from the import path). Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 12 +- tensorrt_llm/executor/postproc_worker.py | 27 +++- tensorrt_llm/executor/postprocessor_hook.py | 78 ++-------- tensorrt_llm/executor/result.py | 21 +-- tensorrt_llm/llmapi/llm.py | 60 ++++---- .../executor/test_postprocessor_hook.py | 135 ++++++++++-------- 6 files changed, 154 insertions(+), 179 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index 93b9f4e5c06f..932cef6cbeeb 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -8,8 +8,9 @@ lets a deployment rewrite, redact, suppress, or terminate model output — inclu spans the chunks of a streamed response — without modifying TensorRT LLM source. The hook is a plain Python callable class supplied by import path, mirroring `--custom_tokenizer`. It -is built once per process and invoked once per output, per streaming chunk (plus a final call), so it -can hold its own per-request state. 
+is owned by the `LLM` instance (and built once in each post-processing worker process when those are +enabled) and invoked once per output, per streaming chunk (plus a final call), so it can hold its own +per-request state. Independent `LLM` instances in one process each own a separate hook instance. ```{note} This feature is a prototype and its interface may change in a future release. @@ -147,9 +148,10 @@ class SuppressHook: ## Per-request state -The hook instance is built once per process and shared across all requests handled by that process, so -any per-request state must be keyed by `chunk.request_id` and released when `chunk.is_final` is seen -(or after a `terminate`). State is not shared across processes; when the post-processing worker pool is +The hook instance is owned by the `LLM` (built once in each post-processing worker process when the +pool is enabled) and shared across all requests it handles, so any per-request state must be keyed by +`chunk.request_id` and released when `chunk.is_final` is seen (or after a `terminate`). State is not +shared across processes or across separate `LLM` instances; when the post-processing worker pool is enabled, all chunks of a single request are still routed to the same worker, so per-request state remains consistent for that request. 
diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 069f445b0725..a4cecc11e127 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -14,6 +14,7 @@ from ..logger import logger from ..sampling_params import SamplingParams from .ipc import ZeroMqQueue +from .postprocessor_hook import load_post_processor_hook from .utils import ErrorResponse, is_llm_response if TYPE_CHECKING: @@ -90,6 +91,7 @@ def __init__( tokenizer_dir: str, record_creator: Callable[ ["PostprocWorker.Input", TransformersTokenizer], Any], + post_processor_hook: Optional[str] = None, ): ''' Args: @@ -98,6 +100,7 @@ def __init__( tokenizer_dir (str): The directory to load tokenizer. record_creator (Callable[["ResponsePostprocessWorker.Input"], Any]): A creator for creating a record for a request. result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback handles the final result. + post_processor_hook (Optional[str]): Import path of the user post-processing hook (TRTLLM-12622). This worker builds and owns one instance, threaded onto each record it creates. ''' self._records: Dict[int, GenerationResult] = {} @@ -118,6 +121,13 @@ def __init__( # Load the tokenizer and share in all records self._tokenizer = load_hf_tokenizer(tokenizer_dir) + # Build the user post-processing hook once (TRTLLM-12622) and own it for + # this worker's lifetime, mirroring the tokenizer above; threaded onto + # each record in ``_handle_input``. None when unconfigured. 
+ self._post_processor_hook = ( + load_post_processor_hook(post_processor_hook) + if post_processor_hook else None) + @staticmethod def default_record_creator( inp: "PostprocWorker.Input", tokenizer: TransformersTokenizer @@ -149,6 +159,12 @@ async def _handle_input( # TODO: support variant creation later self._records[req_id] = self._record_creator( input, self._tokenizer) + # Thread this worker's owned hook onto the record (TRTLLM-12622), + # alongside the tokenizer the record_creator already received, so + # the detok chokepoint applies it. Set here (not in the + # record_creator signature) to keep custom record_creators working. + self._records[ + req_id]._post_processor_hook = self._post_processor_hook if input.disaggregated_params is not None: self._records[ req_id]._disaggregated_params = input.disaggregated_params @@ -282,13 +298,12 @@ def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]], tokenizer_dir: str, record_creator: Callable, post_processor_hook: Optional[str] = None): - # Record the configured post-processing hook (TRTLLM-12622) for this worker - # process so the detok chokepoint in DetokenizedGenerationResultBase can - # apply it. - from .postprocessor_hook import set_configured_post_processor_hook - set_configured_post_processor_hook(post_processor_hook) + # The worker owns its post-processing hook instance (TRTLLM-12622): pass the + # import path through to PostprocWorker, which builds it once and threads it + # onto each record for the detok chokepoint in DetokenizedGenerationResultBase. 
worker = PostprocWorker(feedin_ipc_addr, feedout_ipc_addr, tokenizer_dir=tokenizer_dir, - record_creator=record_creator) + record_creator=record_creator, + post_processor_hook=post_processor_hook) worker.start() diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index a39cf3e054b9..27c3079cbbe2 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -5,9 +5,15 @@ This provides a native, per-request, stateful post-processing seam equivalent to a Triton python-backend post-processor. A user supplies a picklable, importable callable class via the ``--post_processor`` import path; trtllm -builds one instance per process and invokes it once per output, per streaming -chunk (plus a final call), *after* detokenization and *before* the per-endpoint -response formatter. +builds one instance per owner (the ``LLM`` for the in-proxy detok path, and +each post-processing worker process when enabled) and invokes it once per +output, per streaming chunk (plus a final call), *after* detokenization and +*before* the per-endpoint response formatter. + +Ownership is per-instance, not a process global: the hook instance is built by +its owner and threaded onto each result alongside the tokenizer (see +``DetokenizedGenerationResultBase``), so independent ``LLM`` instances in one +process stay isolated. The hook owns its own per-request state (keyed by ``chunk.request_id``) exactly like Triton's model-managed ``self.sequences = {}`` pattern; trtllm passes only @@ -31,58 +37,10 @@ "terminate", "apply_post_processor_hook", "load_post_processor_hook", - "get_post_processor_hook", - "set_configured_post_processor_hook", - "get_configured_post_processor_hook", ] logger = logging.getLogger(__name__) -# Process-level cache so the hook instance is built once per process (mirrors -# the "build once in the worker" precedent of ``record_creator``). 
Keyed by -# import path; the per-request state lives inside the instance. -_HOOK_INSTANCE_CACHE: dict = {} - -# The post-processing hook is global server config (one pipeline for all -# requests), and detokenization runs in different processes depending on -# ``--num_postprocess_workers`` (the postproc worker process when enabled, the -# proxy/serving process otherwise). Each such process records the configured -# import path here at startup; the detok chokepoint reads it. -_CONFIGURED_HOOK_PATH: Optional[str] = None - - -def set_configured_post_processor_hook(import_path: Optional[str]) -> None: - """Record the configured hook import path for this process (or ``None``). - - The hook is process-global server configuration (one pipeline for all - requests in this process), so registering a *different* non-``None`` hook - while one is already active is rejected: a second ``LLM`` in the same - process with a conflicting ``--post_processor`` would otherwise silently - start applying the wrong hook to the already-running instance's responses. - Re-registering the same path is a no-op, and clearing to ``None`` (e.g. on - shutdown) is always allowed. In normal serving each server/rank process - hosts a single ``LLM``, so this only guards genuinely mixed-LLM processes. - """ - global _CONFIGURED_HOOK_PATH - if ( - import_path is not None - and _CONFIGURED_HOOK_PATH is not None - and import_path != _CONFIGURED_HOOK_PATH - ): - raise RuntimeError( - "A different post-processor hook is already registered in this " - f"process ({_CONFIGURED_HOOK_PATH!r}); refusing to register " - f"{import_path!r}. The post-processing hook is process-global " - "configuration; running multiple LLMs with different " - "--post_processor values in one process is not supported." 
- ) - _CONFIGURED_HOOK_PATH = import_path - - -def get_configured_post_processor_hook() -> Optional[str]: - """Return the hook import path configured for this process, if any.""" - return _CONFIGURED_HOOK_PATH - def load_post_processor_hook(import_path: str) -> "PostProcessorHook": """Build a post-processor hook instance from a dotted import path. @@ -92,6 +50,11 @@ def load_post_processor_hook(import_path: str) -> "PostProcessorHook": it with no arguments. The class must be importable and picklable so it can cross the post-processing worker process boundary. + Each owner (the ``LLM`` and each postproc worker) calls this once and holds + the returned instance for the lifetime of the owner; the instance is never + pickled across a process boundary (only ``import_path`` is), and per-request + state lives inside it. + Args: import_path: Dotted path to the hook class, e.g. ``'my_pkg.guardrail.MyPostProcessor'``. @@ -115,19 +78,6 @@ def load_post_processor_hook(import_path: str) -> "PostProcessorHook": ) from e -def get_post_processor_hook(import_path: str) -> "PostProcessorHook": - """Return the process-cached hook instance for ``import_path``. - - Builds it on first use and reuses it thereafter so per-request state held by - the instance persists across chunks within this process. - """ - hook = _HOOK_INSTANCE_CACHE.get(import_path) - if hook is None: - hook = load_post_processor_hook(import_path) - _HOOK_INSTANCE_CACHE[import_path] = hook - return hook - - @dataclasses.dataclass class PostProcChunk: """The payload handed to the post-processing hook for one output chunk. 
diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 3e5a9acb24e6..7314259daff0 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -30,9 +30,7 @@ from ..metrics.perf_utils import \ process_req_perf_metrics as _process_req_perf_metrics from ..sampling_params import LogprobParams, SamplingParams -from .postprocessor_hook import (apply_post_processor_hook, - get_configured_post_processor_hook, - get_post_processor_hook) +from .postprocessor_hook import PostProcessorHook, apply_post_processor_hook from .utils import ErrorResponse, has_event_loop, is_llm_response if TYPE_CHECKING: @@ -817,7 +815,8 @@ def __init__(self, tokenizer: Optional[Callable] = None, streaming: bool = False, background_error_handler: Optional[Callable] = None, - postproc_params: Optional["PostprocParams"] = None): + postproc_params: Optional["PostprocParams"] = None, + post_processor_hook: Optional[PostProcessorHook] = None): super().__init__( id, sampling_params, @@ -826,6 +825,10 @@ def __init__(self, ) self.tokenizer = tokenizer self._streaming = streaming + # User post-processing hook (TRTLLM-12622) owned by this result's + # creator (the LLM for the in-proxy path, the worker for the worker + # path) and threaded in alongside the tokenizer; None when unconfigured. + self._post_processor_hook = post_processor_hook def _handle_response(self, response: "GenerationExecutor.Response"): GenerationResultBase._handle_response(self, response) @@ -887,13 +890,13 @@ def _apply_post_processor_hook(self): Runs after detok populated ``text``/``text_diff`` and before any per-endpoint formatter reads them. Shared by the postproc-worker path - and the in-proxy path; the hook is configured per-process and read from - the process-global set at startup. 
+ and the in-proxy path; the hook instance is owned by this result's + creator and threaded in via ``post_processor_hook`` (``None`` when + unconfigured), so independent ``LLM`` instances stay isolated. """ - import_path = get_configured_post_processor_hook() - if not import_path: + hook = self._post_processor_hook + if hook is None: return - hook = get_post_processor_hook(import_path) apply_post_processor_hook(hook, self, streaming=self._streaming) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 081cb0ebfc97..f76b7d8a1de1 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -32,6 +32,8 @@ GenerationResult, IterationResult, LoRARequest, PostprocWorkerConfig, PromptAdapterRequest) from ..executor.postproc_worker import PostprocParams +from ..executor.postprocessor_hook import (PostProcessorHook, + load_post_processor_hook) from ..executor.request import DEFAULT_REQUEST_PRIORITY from ..executor.utils import (RequestError, create_mpi_comm_session, get_spawn_proxy_process_env) @@ -76,15 +78,21 @@ def __init__(self) -> None: @classmethod def _from_generation_result( - cls, - generation_result: GenerationResult, - prompt: Optional[str] = None, - tokenizer: Optional[TokenizerBase] = None) -> 'RequestOutput': + cls, + generation_result: GenerationResult, + prompt: Optional[str] = None, + tokenizer: Optional[TokenizerBase] = None, + post_processor_hook: Optional[PostProcessorHook] = None + ) -> 'RequestOutput': inst = cls.__new__(cls) inst.__dict__.update(generation_result.__dict__) inst.tokenizer = tokenizer inst._streaming = generation_result._streaming inst._prompt = prompt + # User post-processing hook (TRTLLM-12622) owned by the LLM instance; + # threaded onto the result the user holds (the object the in-proxy detok + # runs on) alongside the tokenizer. None when unconfigured. 
+ inst._post_processor_hook = post_processor_hook return inst @property @@ -230,25 +238,15 @@ def __init__(self, "yellow") self.mpi_session = self.args.mpi_session - # Record the configured post-processing hook (TRTLLM-12622) for this - # (LLM/proxy) process. When postproc workers are disabled, the detok - # chokepoint runs here on RequestOutput; when enabled, each worker - # process records it separately via postproc_worker_main. - from ..executor.postprocessor_hook import ( - get_post_processor_hook, set_configured_post_processor_hook) - _post_processor_hook = getattr(self.args, "post_processor", None) - if _post_processor_hook: - # Resolve eagerly so a bad import path fails fast at startup (and - # primes the per-process build-once cache) rather than erroring on - # the first request. Validate BEFORE recording the process-global so - # a failed import leaves no stale hook registered for this process. - get_post_processor_hook(_post_processor_hook) - # Record the configured hook for this process. Raises if a different - # hook is already registered in-process (mixed-LLM process), so a second - # instance can never silently apply the wrong hook to the first's - # responses. - set_configured_post_processor_hook(_post_processor_hook) - self._post_processor_hook = _post_processor_hook + # Build this LLM's post-processing hook instance (TRTLLM-12622), if + # configured. Ownership is per-instance: this instance is threaded onto + # the results this LLM produces (in-proxy detok path), and each postproc + # worker builds its own from the same import path. Resolving here also + # fails fast on a bad import path at startup rather than per-request. 
+ _post_processor_path = getattr(self.args, "post_processor", None) + self._post_processor_hook = ( + load_post_processor_hook(_post_processor_path) + if _post_processor_path else None) if self.args.parallel_config.is_multi_gpu: if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( @@ -547,8 +545,11 @@ def generate_async( result.metrics_dict.update( {MetricNames.ARRIVAL_TIMESTAMP: time.time()}) - return RequestOutput._from_generation_result(result, prompt, - self.tokenizer) + return RequestOutput._from_generation_result( + result, + prompt, + self.tokenizer, + post_processor_hook=self._post_processor_hook) def _preprocess( self, @@ -1210,15 +1211,6 @@ def shutdown(self) -> None: self.mpi_session.shutdown() self.mpi_session = None - # Release this process's post-processor hook registration so a - # subsequent LLM with a different --post_processor can be created in the - # same process (e.g. sequential use in tests or notebooks). - if getattr(self, "_post_processor_hook", None): - from ..executor.postprocessor_hook import \ - set_configured_post_processor_hook - set_configured_post_processor_hook(None) - self._post_processor_hook = None - def _check_health(self) -> bool: """Check if the LLM is healthy. 
diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 5b771aa90e7b..8a668568f6a7 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -4,12 +4,10 @@ import pytest -from tensorrt_llm.executor import postprocessor_hook as _pph from tensorrt_llm.executor.postprocessor_hook import ( PostProcChunk, apply_post_processor_hook, emit, - get_post_processor_hook, load_post_processor_hook, suppress, terminate, @@ -17,29 +15,27 @@ from tensorrt_llm.executor.result import CompletionOutput -@pytest.fixture(autouse=True) -def _reset_hook_process_globals(): - """Isolate the module-level process globals between tests.""" - saved_path = _pph.get_configured_post_processor_hook() - saved_cache = dict(_pph._HOOK_INSTANCE_CACHE) - _pph.set_configured_post_processor_hook(None) - _pph._HOOK_INSTANCE_CACHE.clear() - try: - yield - finally: - _pph.set_configured_post_processor_hook(saved_path) - _pph._HOOK_INSTANCE_CACHE.clear() - _pph._HOOK_INSTANCE_CACHE.update(saved_cache) - - class _FakeResult: """Minimal stand-in for a GenerationResult at the detok chokepoint.""" - def __init__(self, outputs, req_id=1, done=False, aborted=False, has_abort=False): + def __init__( + self, + outputs, + req_id=1, + done=False, + aborted=False, + has_abort=False, + streaming=True, + post_processor_hook=None, + ): self.outputs = outputs self.id = req_id self._done = done self._aborted = aborted + self._streaming = streaming + # Per-instance hook ownership (TRTLLM-12622): the detok read site reads + # this attribute rather than a process global. 
+ self._post_processor_hook = post_processor_hook self.abort_called = 0 if has_abort: self.abort = self._abort @@ -226,53 +222,70 @@ def test_loader_raises_on_bad_path(): load_post_processor_hook("no.such.module.Nope") -def test_get_hook_builds_once_per_process(): - a = get_post_processor_hook("collections.OrderedDict") - b = get_post_processor_hook("collections.OrderedDict") - assert a is b +def test_loader_builds_independent_instances(): + """Each owner builds its own instance (no shared process-global cache). + + This is the core of per-instance ownership (TRTLLM-12622): two owners + loading the same import path get distinct instances, so their per-request + state never collides. + """ + a = load_post_processor_hook("collections.OrderedDict") + b = load_post_processor_hook("collections.OrderedDict") + assert a is not b + + +def test_apply_method_reads_hook_from_instance_attribute(): + """The detok read site applies the hook owned by the result instance.""" + from tensorrt_llm.executor.result import DetokenizedGenerationResultBase + + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out], post_processor_hook=lambda c: emit(c.text_diff.upper())) + + # Call the real read site against the per-instance attribute. 
+ DetokenizedGenerationResultBase._apply_post_processor_hook(result) + + assert out.text == "hello WORLD" -def test_configured_hook_global_roundtrip(): - from tensorrt_llm.executor.postprocessor_hook import ( - get_configured_post_processor_hook, - set_configured_post_processor_hook, - ) +def test_apply_method_is_noop_when_instance_has_no_hook(): + """With no hook on the instance, the chunk passes through untouched.""" + from tensorrt_llm.executor.result import DetokenizedGenerationResultBase + + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out], post_processor_hook=None) + + DetokenizedGenerationResultBase._apply_post_processor_hook(result) - try: - assert get_configured_post_processor_hook() is None - set_configured_post_processor_hook("my_pkg.guardrail.G") - assert get_configured_post_processor_hook() == "my_pkg.guardrail.G" - finally: - set_configured_post_processor_hook(None) - assert get_configured_post_processor_hook() is None + assert out.text == "hello world" -def test_configured_hook_rejects_conflicting_registration(): - """A conflicting in-process hook must fail fast (TRTLLM-12622). +def test_independent_instances_keep_separate_state(): + """Two owners' hook instances maintain isolated per-request state. - A second, different hook in the same process is rejected rather than - silently clobbering the global and applying the wrong hook. + Demonstrates the Model-2 guarantee that distinct LLM/worker owners do not + share hook state, even for the same request id. """ - from tensorrt_llm.executor.postprocessor_hook import ( - get_configured_post_processor_hook, - set_configured_post_processor_hook, - ) - - try: - set_configured_post_processor_hook("my_pkg.guardrail.A") - # Re-registering the same path is an idempotent no-op. 
- set_configured_post_processor_hook("my_pkg.guardrail.A") - assert get_configured_post_processor_hook() == "my_pkg.guardrail.A" - - # A different non-None hook is rejected and the active one is unchanged. - with pytest.raises(RuntimeError, match="already registered"): - set_configured_post_processor_hook("my_pkg.guardrail.B") - assert get_configured_post_processor_hook() == "my_pkg.guardrail.A" - - # Clearing to None is always allowed (e.g. on shutdown), after which a - # new hook can be registered. - set_configured_post_processor_hook(None) - set_configured_post_processor_hook("my_pkg.guardrail.B") - assert get_configured_post_processor_hook() == "my_pkg.guardrail.B" - finally: - set_configured_post_processor_hook(None) + + class Counter: + def __init__(self): + self.state = {} + + def __call__(self, chunk: PostProcChunk): + n = self.state.get(chunk.request_id, 0) + 1 + self.state[chunk.request_id] = n + return emit(f"{chunk.text_diff}#{n}") + + hook_a, hook_b = Counter(), Counter() + # Same request id (1) routed to two different owners. + ra = _FakeResult([_make_output("x", 0)], req_id=1, post_processor_hook=hook_a) + rb = _FakeResult([_make_output("y", 0)], req_id=1, post_processor_hook=hook_b) + + from tensorrt_llm.executor.result import DetokenizedGenerationResultBase + + DetokenizedGenerationResultBase._apply_post_processor_hook(ra) + DetokenizedGenerationResultBase._apply_post_processor_hook(ra) + DetokenizedGenerationResultBase._apply_post_processor_hook(rb) + + # hook_a counted request 1 twice; hook_b counted it once — fully isolated. 
+ assert hook_a.state[1] == 2 + assert hook_b.state[1] == 1 From 4068a2d5f5593fc03204e751e44026b3ad5c17eb Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 14:26:01 -0700 Subject: [PATCH 08/16] [TRTLLM-12622][chore] Trim verbose comments and a redundant test MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten the post-processing hook changes after the per-instance refactor: - postprocessor_hook.py: fix the now-stale "built once per process" wording in the PostProcessorHook docstring (it is per-owner), drop a duplicated ownership paragraph in the module docstring and a redundant token-id NOTE, and shorten the is_final comment. - postproc_worker.py: shorten the hook-injection comment. - post-processor-hook.md: condense the batching note. - test_postprocessor_hook.py: drop test_independent_instances_keep_separate_state — two distinct hook instances having separate state is tautological and the read-site wiring is already covered by the apply-method test. Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 8 ++--- tensorrt_llm/executor/postproc_worker.py | 7 ++-- tensorrt_llm/executor/postprocessor_hook.py | 23 ++++--------- .../executor/test_postprocessor_hook.py | 32 ------------------- 4 files changed, 13 insertions(+), 57 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index 932cef6cbeeb..d216920334fe 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -155,11 +155,9 @@ shared across processes or across separate `LLM` instances; when the post-proces enabled, all chunks of a single request are still routed to the same worker, so per-request state remains consistent for that request. 
-Engine-level batching is transparent to the hook: even when many requests are batched and run together -in the engine, the hook is still invoked **once per request** (per output, per chunk), with -`chunk.request_id` identifying which request the chunk belongs to. There is no batched-call form — the -hook never receives more than one request's data in a single call, so keying state on `request_id` is -sufficient to keep concurrent requests isolated. +Engine-level batching is transparent to the hook: even when requests are batched together in the +engine, the hook is invoked **once per request** (per output, per chunk) — there is no batched-call +form — so keying state on `chunk.request_id` is sufficient to keep concurrent requests isolated. ## Supported endpoints and limitations diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index a4cecc11e127..0b2d1a6d521f 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -159,10 +159,9 @@ async def _handle_input( # TODO: support variant creation later self._records[req_id] = self._record_creator( input, self._tokenizer) - # Thread this worker's owned hook onto the record (TRTLLM-12622), - # alongside the tokenizer the record_creator already received, so - # the detok chokepoint applies it. Set here (not in the - # record_creator signature) to keep custom record_creators working. + # Thread this worker's owned hook onto the record (TRTLLM-12622) + # for the detok chokepoint. Set here, not via the record_creator + # signature, so custom record_creators keep working. 
self._records[ req_id]._post_processor_hook = self._post_processor_hook if input.disaggregated_params is not None: diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 27c3079cbbe2..92f93c886824 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -8,12 +8,9 @@ builds one instance per owner (the ``LLM`` for the in-proxy detok path, and each post-processing worker process when enabled) and invokes it once per output, per streaming chunk (plus a final call), *after* detokenization and -*before* the per-endpoint response formatter. - -Ownership is per-instance, not a process global: the hook instance is built by -its owner and threaded onto each result alongside the tokenizer (see -``DetokenizedGenerationResultBase``), so independent ``LLM`` instances in one -process stay isolated. +*before* the per-endpoint response formatter. Ownership is per-instance, not a +process global: the instance is threaded onto each result alongside the +tokenizer, so independent ``LLM`` instances in one process stay isolated. The hook owns its own per-request state (keyed by ``chunk.request_id``) exactly like Triton's model-managed ``self.sequences = {}`` pattern; trtllm passes only @@ -141,12 +138,9 @@ def terminate(reason: str) -> PostProcVerdict: class PostProcessorHook(Protocol): """The interface a user post-processor implements. - The instance is built once per process (its ``__init__`` is the one-time + The instance is built once per owner (its ``__init__`` is the one-time setup) and called once per output, per chunk. It owns any per-request state and is responsible for releasing it on ``chunk.is_final``. - - NOTE: rewriting/suppressing text does not rewrite the underlying token ids; - callers that read both text and token ids should expect them to diverge. """ def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: ... 
@@ -176,12 +170,9 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) worker or crash the serving loop. This is consistent across the in-proxy and postproc-worker paths. """ - # ``is_final`` is request-level (``result._done``), not per-output: a - # ``terminate`` verdict on one output marks the whole request done. Under the - # locked 1:1 single-output scope (TRTLLM-12622) this is exact; hooks key their - # per-request state on ``request_id`` and release it on ``is_final``, so - # request-level finality is the correct cleanup signal regardless of output - # count. + # ``is_final`` is request-level (``result._done``), not per-output; under the + # locked 1:1 single-output scope (TRTLLM-12622) this is the exact cleanup + # signal for hooks that release per-request state on ``is_final``. is_final = result._done for output in result.outputs: chunk = PostProcChunk( diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 8a668568f6a7..5a94f8ebec1f 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -257,35 +257,3 @@ def test_apply_method_is_noop_when_instance_has_no_hook(): DetokenizedGenerationResultBase._apply_post_processor_hook(result) assert out.text == "hello world" - - -def test_independent_instances_keep_separate_state(): - """Two owners' hook instances maintain isolated per-request state. - - Demonstrates the Model-2 guarantee that distinct LLM/worker owners do not - share hook state, even for the same request id. - """ - - class Counter: - def __init__(self): - self.state = {} - - def __call__(self, chunk: PostProcChunk): - n = self.state.get(chunk.request_id, 0) + 1 - self.state[chunk.request_id] = n - return emit(f"{chunk.text_diff}#{n}") - - hook_a, hook_b = Counter(), Counter() - # Same request id (1) routed to two different owners. 
- ra = _FakeResult([_make_output("x", 0)], req_id=1, post_processor_hook=hook_a) - rb = _FakeResult([_make_output("y", 0)], req_id=1, post_processor_hook=hook_b) - - from tensorrt_llm.executor.result import DetokenizedGenerationResultBase - - DetokenizedGenerationResultBase._apply_post_processor_hook(ra) - DetokenizedGenerationResultBase._apply_post_processor_hook(ra) - DetokenizedGenerationResultBase._apply_post_processor_hook(rb) - - # hook_a counted request 1 twice; hook_b counted it once — fully isolated. - assert hook_a.state[1] == 2 - assert hook_b.state[1] == 1 From e20bbd151c93a2e522aefa02abeed92fdd32c095 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 14:45:55 -0700 Subject: [PATCH 09/16] [TRTLLM-12622][fix] Run post-processing hook regardless of detokenize flag Address review of the post-processing hook branch (two-agent review). Main fix: a `/v1/completions` request with `detokenize=false` previously bypassed the hook entirely, because the hook ran inside the detokenize guard. The hook is a server-side guardrail and must not be bypassable by a client flag (matching Triton, where the post-processor is a mandatory ensemble stage). The detok+hook block now runs when detokenize is set OR a post_processor is configured; the returned channel is unchanged (the response formatter honors the request's detokenize flag separately), and suppress/terminate still withhold the token-id channel. Also: - Extract the harmony fail-fast into OpenAIServer._ensure_post_processor_supported and unit-test it directly. - Add an e2e test that detokenize=false does not bypass the terminate hook. - proxy.py / postprocessor_hook.py comment accuracy; postproc_worker.py uses pop(client_id, None) like its hardened siblings. - Docs: not-bypassable guarantee, disagg note, soften the untested responses endpoint claim. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 12 +++++-- tensorrt_llm/executor/postproc_worker.py | 2 +- tensorrt_llm/executor/postprocessor_hook.py | 3 +- tensorrt_llm/executor/proxy.py | 13 ++++---- tensorrt_llm/executor/result.py | 9 ++++- tensorrt_llm/serve/openai_server.py | 33 ++++++++++++------- .../executor/test_postprocessor_hook.py | 17 ++++++++++ .../apps/_test_openai_post_processor.py | 20 +++++++++++ 8 files changed, 86 insertions(+), 23 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index d216920334fe..af3e83e28020 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -161,11 +161,19 @@ form — so keying state on `chunk.request_id` is sufficient to keep concurrent ## Supported endpoints and limitations -- **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. The hook also - applies to the `responses` endpoint for non-harmony models. +- **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. The hook is + expected to apply to the `responses` endpoint for non-harmony models as well (shared detokenization + path), though that endpoint is not covered by the current end-to-end tests. +- **Not client-bypassable**: the hook is a server-side guardrail, so it runs on every response even + when a `completions` request sets `detokenize=false`. The server detokenizes for the hook regardless; + the `detokenize` flag still controls only the returned channel (text vs. token ids), and a + `suppress`/`terminate` verdict withholds the token-id channel too. - **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from raw token ids, it would bypass the text-based hook, so the server fails fast at startup when `--post_processor` is combined with a harmony model. 
+- **Disaggregated serving**: the context and generation servers are separate processes, each running + the hook on its own phase under a different `request_id`; per-request state cannot be correlated + across the two. A `terminate` on one phase does not propagate to the other. - **Text vs. token ids**: rewriting or suppressing text does not rewrite the underlying `token_ids` or `logprobs`. Clients that read both should expect them to diverge. - **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 0b2d1a6d521f..ca538890535d 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -247,7 +247,7 @@ async def handle_single_input(inp: PostprocWorker.Input, num_generated_tokens=num_generated_tokens, )) if is_final: - self._records.pop(client_id) + self._records.pop(client_id, None) except Exception as e: logger.error( f"Postprocessing error for client {client_id}: {e}\n" diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 92f93c886824..bdc4bc0e7928 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -209,7 +209,8 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) # Cancel the engine request as well. On the in-proxy path this stops # wasted generation; on the worker path the record's abort() only # sets the flag and the engine is cancelled by the proxy via - # should_abort. Guarded for results that expose no abort(). + # should_abort. The getattr guard is defensive (real results always + # define abort()). 
abort = getattr(result, "abort", None) if callable(abort): try: diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index b9298640cb8f..c9f14555a3bc 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -327,12 +327,13 @@ def process_res(res): nonlocal event_loop nonlocal async_queues - # The result may have already been finalized and removed — e.g. a - # post-processor hook (TRTLLM-12622) returned `terminate`, which - # marks the result done and pops it here, while the engine still has - # in-flight responses for the same client_id (abort is async, and the - # postproc worker recreates a record for any late response). Drop - # such late responses instead of crashing with a KeyError. + # The result may have already been finalized and removed — e.g. on + # the worker path a post-processor hook (TRTLLM-12622) returned + # `terminate`, so the worker sent a final Output that popped this + # client_id below, while the engine still has in-flight responses for + # it (abort is async, and the postproc worker recreates a record for + # any late response). Drop such late responses instead of crashing + # with a KeyError. 
result = self._results.get(client_id) if result is None: return diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 8fc62a675d2b..0219fb1c734b 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -846,7 +846,14 @@ def _handle_response(self, response: "GenerationExecutor.Response"): 'spaces_between_special_tokens': self.sampling_params.spaces_between_special_tokens } - if self.sampling_params.detokenize and self.tokenizer is not None: + # Detokenize when the client asked for text, OR whenever a post-processing + # hook (TRTLLM-12622) is configured: the hook is a server-side guardrail + # and must run on every response regardless of the client's + # ``detokenize`` flag (which only controls the returned channel, honored + # separately by the response formatter). Without this, a client could + # bypass the guardrail with ``detokenize=False``. + if (self.sampling_params.detokenize or self._post_processor_hook + is not None) and self.tokenizer is not None: for beam_output in self.outputs: beam_output._last_text_len = len(beam_output.text) if hasattr( diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 5dcf4b2995f8..2b438d0daae1 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -403,18 +403,9 @@ def _init_llm(self, chat_template: Optional[str] = None): else: self.use_harmony = (type(self.model_config).model_type == "gpt_oss") - # The harmony (gpt-oss) path rebuilds the client-visible output from raw - # output token ids rather than the detokenized text, so the - # post-processing hook (TRTLLM-12622), which operates on detok text, - # cannot act there. Fail fast rather than silently bypassing a guardrail. 
- if self.use_harmony and getattr(self.generator.args, "post_processor", - None): - raise ValueError( - "--post_processor is not supported with harmony/gpt-oss models " - "in this version: the harmony output path is reconstructed from " - "raw token ids and would bypass the text-based hook. Disable the " - "hook or set DISABLE_HARMONY_ADAPTER=1 if the harmony path is " - "not needed.") + self._ensure_post_processor_supported( + self.use_harmony, + getattr(self.generator.args, "post_processor", None)) self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: @@ -460,6 +451,24 @@ def _init_llm(self, chat_template: Optional[str] = None): self.perf_metrics = deque(maxlen=max_perf_metrics) self.perf_metrics_lock = asyncio.Lock() + @staticmethod + def _ensure_post_processor_supported(use_harmony: bool, + post_processor: Optional[str]) -> None: + """Reject ``--post_processor`` combined with a harmony/gpt-oss model. + + The harmony output path rebuilds the client-visible output from raw + output token ids rather than the detokenized text, so the text-based + post-processing hook (TRTLLM-12622) cannot act there. Fail fast at + startup rather than silently bypassing a guardrail. + """ + if use_harmony and post_processor: + raise ValueError( + "--post_processor is not supported with harmony/gpt-oss models " + "in this version: the harmony output path is reconstructed from " + "raw token ids and would bypass the text-based hook. 
Disable the " + "hook or set DISABLE_HARMONY_ADAPTER=1 if the harmony path is " + "not needed.") + def _logit_bias_vocab_size(self) -> int: for config in (self.model_config, getattr(self.model_config, "text_config", None)): diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 5a94f8ebec1f..e30c00c6b2ec 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -257,3 +257,20 @@ def test_apply_method_is_noop_when_instance_has_no_hook(): DetokenizedGenerationResultBase._apply_post_processor_hook(result) assert out.text == "hello world" + + +def test_harmony_model_rejects_post_processor(): + """A harmony/gpt-oss model + post_processor must fail fast (TRTLLM-12622). + + The harmony output path is rebuilt from raw token ids and would bypass the + text-based hook, so the server refuses the combination at startup. + """ + from tensorrt_llm.serve.openai_server import OpenAIServer + + guard = OpenAIServer._ensure_post_processor_supported + with pytest.raises(ValueError, match="not supported with harmony"): + guard(use_harmony=True, post_processor="my_pkg.guardrail.Hook") + # Every other combination is allowed. 
+ guard(use_harmony=False, post_processor="my_pkg.guardrail.Hook") + guard(use_harmony=True, post_processor=None) + guard(use_harmony=False, post_processor=None) diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py index 0be202efefff..28e86c94ff00 100644 --- a/tests/unittest/llmapi/apps/_test_openai_post_processor.py +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -105,6 +105,26 @@ def test_completions_non_streaming(client: openai.OpenAI, model_name: str, hook: assert completion.choices[0].finish_reason == "stop" +def test_completions_detokenize_false_does_not_bypass_hook( + client: openai.OpenAI, model_name: str, hook: str +): + """A server-side hook is a guardrail and must run even when the client sets + ``detokenize=false`` — that flag controls only the returned channel, not + whether the hook executes (TRTLLM-12622).""" + completion = client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + extra_body={"detokenize": False}, + ) + if hook == "terminate": + # The terminate hook fires on the first chunk regardless of the text + # channel, so the request stops early instead of running to max_tokens. + # If detokenize=false bypassed the hook, this would be "length". 
+ assert completion.choices[0].finish_reason == "stop" + + @pytest.mark.asyncio(loop_scope="module") async def test_completions_streaming(async_client: openai.AsyncOpenAI, model_name: str, hook: str): stream = await async_client.completions.create( From 3699efe18ce421649c44013d9d97b05d7200c52b Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 15:30:21 -0700 Subject: [PATCH 10/16] [TRTLLM-12622][fix] Withhold all client channels on suppress/terminate MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Close two more guardrail-bypass gaps found in review, by making the hook's withholding contract complete and explicit rather than patching channels one at a time. - All-channel withholding: a channel×endpoint audit showed suppress/terminate only blanked the streaming diff views; the non-streaming formatters emit the FULL token_ids (completions) and FULL logprobs (chat + completions), which still leaked withheld content. _withhold_token_channel now truncates the full token_ids/logprobs to the presented prefix for non-streaming (mirroring how .text is blanked), keeping the proven streaming watermark path unchanged. - Require a tokenizer: post_processor + skip_tokenizer_init is now rejected in BaseLlmArgs (the text-based hook has no text without a tokenizer), mirroring the harmony fail-fast. - Docs: verdict table + limitations clarified (suppress/terminate withhold all channels; emit is text-only; tokenizer required). - Tests: non-streaming full-channel suppress unit test; skip_tokenizer_init rejection test; e2e asserts token_ids are withheld under detokenize=false. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 19 +++++---- tensorrt_llm/executor/postprocessor_hook.py | 39 +++++++++++++------ tensorrt_llm/llmapi/llm_args.py | 10 +++++ .../executor/test_postprocessor_hook.py | 22 +++++++++++ .../apps/_test_openai_post_processor.py | 9 ++++- tests/unittest/llmapi/test_llm_args.py | 14 +++++++ 6 files changed, 94 insertions(+), 19 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index af3e83e28020..a1d85912c6ac 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -83,9 +83,9 @@ Return one of the following helpers from `__call__`: | Helper | Effect | |--------|--------| -| `emit(text)` | Emit `text` for this chunk. Use it to pass output through unchanged (`emit(chunk.text_diff)`) or to rewrite/redact it. | -| `suppress()` | Withhold this chunk entirely (no client-visible output for it). | -| `terminate(reason)` | Stop the stream for this request. `reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. | +| `emit(text)` | Emit `text` for this chunk. Use it to pass output through unchanged (`emit(chunk.text_diff)`) or to rewrite/redact it. Affects the **text** channel only — it does not synthesize matching `token_ids` (see *Text vs. token ids* below). | +| `suppress()` | Withhold this chunk entirely across **all** client-visible channels — text, `token_ids`, and `logprobs` (so `detokenize=false` token output is withheld too). | +| `terminate(reason)` | Stop the stream for this request, withholding the terminating chunk on all channels. `reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. 
| ## Usage examples @@ -166,16 +166,21 @@ form — so keying state on `chunk.request_id` is sufficient to keep concurrent path), though that endpoint is not covered by the current end-to-end tests. - **Not client-bypassable**: the hook is a server-side guardrail, so it runs on every response even when a `completions` request sets `detokenize=false`. The server detokenizes for the hook regardless; - the `detokenize` flag still controls only the returned channel (text vs. token ids), and a - `suppress`/`terminate` verdict withholds the token-id channel too. + the `detokenize` flag still controls only the returned channel (text vs. token ids). A + `suppress`/`terminate` verdict withholds **all** client-visible channels — text, `token_ids`, and + `logprobs` — on both the streaming and non-streaming paths, so a client cannot recover withheld + content through any channel. +- **Requires a tokenizer**: the hook needs detokenized text to inspect, so `--post_processor` combined + with `skip_tokenizer_init` is rejected at startup rather than silently disabled. - **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from raw token ids, it would bypass the text-based hook, so the server fails fast at startup when `--post_processor` is combined with a harmony model. - **Disaggregated serving**: the context and generation servers are separate processes, each running the hook on its own phase under a different `request_id`; per-request state cannot be correlated across the two. A `terminate` on one phase does not propagate to the other. -- **Text vs. token ids**: rewriting or suppressing text does not rewrite the underlying `token_ids` or - `logprobs`. Clients that read both should expect them to diverge. +- **Text vs. token ids**: `emit` rewrites the **text** channel only — it does not rewrite the underlying + `token_ids`/`logprobs`, so a client reading both should expect them to diverge. 
(`suppress`/`terminate` + withhold all channels, so they do not diverge.) - **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that rewrites or suppresses text may desync those parsers; prefer `terminate`, or apply such hooks to plain-text requests. diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index bdc4bc0e7928..1b2eb8a3af6f 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -146,16 +146,33 @@ class PostProcessorHook(Protocol): def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: ... -def _withhold_token_channel(output) -> None: - """Withhold the raw token-id / logprob channel for this chunk too. - - Without this, a suppressed/terminated chunk would still leak - ``token_ids_diff`` (e.g. on ``/v1/completions`` with ``detokenize=False``) - and ``logprobs_diff`` even though the detokenized text was blanked. +def _withhold_token_channel(output, streaming: bool) -> None: + """Withhold the raw token-id / logprob channels alongside the blanked text. + + ``suppress``/``terminate`` blank the text channel; the raw token-id and + logprob channels must be withheld too, or a suppressed/terminated output + leaks via them — ``token_ids`` on ``/v1/completions`` with + ``detokenize=False``, and ``logprobs`` on both chat and completions. + + The two response shapes withhold differently, matching what each formatter + emits (verified by the channel audit): + + * **streaming** emits per-chunk *diffs* (``token_ids_diff`` / + ``logprobs_diff``); advancing the diff watermark empties this chunk. + * **non-streaming** emits the *full* ``token_ids`` / ``logprobs``; these are + truncated back to the already-presented prefix, mirroring exactly how the + text channel is blanked to ``output.text[:_last_text_len]``. 
(Outputs + accumulate via ``list.extend`` in the hook's single-output scope, so the + truncation stays consistent across chunks.) """ - output._last_token_ids_len = len(output.token_ids) - if getattr(output, "logprobs", None) is not None: - output._last_logprobs_len = len(output.logprobs) + if streaming: + output._last_token_ids_len = len(output.token_ids) + if getattr(output, "logprobs", None) is not None: + output._last_logprobs_len = len(output.logprobs) + else: + output.token_ids = output.token_ids[: output._last_token_ids_len] + if getattr(output, "logprobs", None) is not None: + output.logprobs = output.logprobs[: output._last_logprobs_len] def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) -> None: @@ -198,10 +215,10 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) output.text = prefix + verdict.text elif verdict.action == "suppress": output.text = prefix - _withhold_token_channel(output) + _withhold_token_channel(output, streaming) elif verdict.action == "terminate": output.text = prefix + verdict.text - _withhold_token_channel(output) + _withhold_token_channel(output, streaming) output.finish_reason = "stop" output.stop_reason = verdict.reason result._aborted = True diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 139623ed3a28..a492a7716020 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3535,6 +3535,16 @@ def validate_parallel_config(self): def validate_and_init_tokenizer(self): """Initialize tokenizer based on configuration.""" if self.skip_tokenizer_init: + # The post-processing hook (TRTLLM-12622) is a text-based guardrail + # and needs detokenized text to inspect; without a tokenizer it could + # never run, so reject the combination rather than silently disabling + # the guardrail (mirrors the harmony fail-fast in OpenAIServer). 
+ if self.post_processor is not None: + raise ValueError( + "post_processor is not supported together with " + "skip_tokenizer_init: the post-processing hook operates on " + "detokenized text, which is unavailable when the tokenizer " + "is skipped.") self.tokenizer = None elif self.custom_tokenizer: # If tokenizer is already a tokenizer object, custom_tokenizer is not compatible diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index e30c00c6b2ec..385eee861060 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -110,6 +110,28 @@ def test_suppress_blanks_token_and_logprob_diffs(): assert out.logprobs_diff == [] +def test_suppress_blanks_full_token_and_logprob_channels_non_streaming(): + """Non-streaming suppress must blank the full token/logprob channels (TRTLLM-12622). + + Non-streaming emits the full token_ids/logprobs (not the diff), so suppress + must truncate the full channels — otherwise a detokenize=False completion + (token_ids) or any logprobs response leaks the withheld output. + """ + out = _make_output("the full answer", last_text_len=0) + # Non-streaming single response: the diff watermark stays at 0. + assert out.token_ids == [1, 2, 3] + assert out.logprobs == [-0.1, -0.2, -0.3] + result = _FakeResult([out], done=True) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=False) + + assert out.text == "" + # The full channels the non-streaming formatter reads are emptied, not just + # the diff view. 
+ assert out.token_ids == [] + assert out.logprobs == [] + + def test_terminate_calls_abort_when_available_and_blanks_token_channel(): out = _make_output("safe bad", last_text_len=len("safe ")) out._last_token_ids_len = 1 diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py index 28e86c94ff00..85139dcc2cb1 100644 --- a/tests/unittest/llmapi/apps/_test_openai_post_processor.py +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -118,11 +118,18 @@ def test_completions_detokenize_false_does_not_bypass_hook( temperature=0.0, extra_body={"detokenize": False}, ) + choice = completion.choices[0] if hook == "terminate": # The terminate hook fires on the first chunk regardless of the text # channel, so the request stops early instead of running to max_tokens. # If detokenize=false bypassed the hook, this would be "length". - assert completion.choices[0].finish_reason == "stop" + assert choice.finish_reason == "stop" + # The withheld content must not leak through the token-id channel that + # detokenize=false returns. + assert not choice.token_ids + elif hook == "suppress": + # detokenize=false returns token_ids; suppress must withhold them too. + assert not choice.token_ids @pytest.mark.asyncio(loop_scope="module") diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 5b73c0e60f7f..8fe8f6dd5e71 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -224,6 +224,20 @@ def test_decoding_type_eagle3_errors_on_tensorrt_backend(): TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg) +def test_post_processor_rejected_with_skip_tokenizer_init(): + """post_processor + skip_tokenizer_init must fail fast (TRTLLM-12622). + + The hook is a text-based guardrail; pairing it with skip_tokenizer_init (no + detokenized text) must be rejected rather than silently disabling it. 
+ """ + with pytest.raises(ValidationError, match="skip_tokenizer_init"): + TorchLlmArgs(model="/tmp/dummy_model", + skip_tokenizer_init=True, + post_processor="my_pkg.guardrail.Hook") + # skip_tokenizer_init alone (no hook) is still fine. + TorchLlmArgs(model="/tmp/dummy_model", skip_tokenizer_init=True) + + class TestModelDefaults: """Test suite for model-specific default overrides functionality.""" From 4582a2da9e6a83c481a2154b0527ed716a439ae5 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 15:47:33 -0700 Subject: [PATCH 11/16] [TRTLLM-12622][doc] Clarify per-chunk withholding semantics MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Third review pass confirmed no bypass: the non-streaming response delivers the content the hook emitted before the first suppress/terminate — identical to what a streaming client receives, so a client cannot obtain more output by choosing non-streaming. Only the wording was imprecise. - _withhold_token_channel docstring: "already-presented" -> "already-emitted" prefix, and note the streaming-consistent semantics. - Doc: state that verdicts are per-chunk (suppress withholds this chunk, terminate keeps prior chunks), consistent across streaming/non-streaming, and that all-or-nothing withholding means suppressing from the first chunk. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 2 ++ tensorrt_llm/executor/postprocessor_hook.py | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index a1d85912c6ac..78a1978d700e 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -87,6 +87,8 @@ Return one of the following helpers from `__call__`: | `suppress()` | Withhold this chunk entirely across **all** client-visible channels — text, `token_ids`, and `logprobs` (so `detokenize=false` token output is withheld too). | | `terminate(reason)` | Stop the stream for this request, withholding the terminating chunk on all channels. `reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. | +Verdicts are **per chunk**: `suppress()` withholds the current chunk, and `terminate()` stops generation while keeping the chunks already emitted before it. This is consistent across streaming and non-streaming — a non-streaming response contains exactly the content the hook emitted before the first `suppress`/`terminate`, i.e. the same content a streaming client would have received. A hook that must withhold the *entire* output (all-or-nothing) should `suppress()` from the first chunk (it sees every chunk) rather than emitting and then terminating. + ## Usage examples ### Rewrite output diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 1b2eb8a3af6f..0e612256b3f9 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -160,10 +160,13 @@ def _withhold_token_channel(output, streaming: bool) -> None: * **streaming** emits per-chunk *diffs* (``token_ids_diff`` / ``logprobs_diff``); advancing the diff watermark empties this chunk. 
* **non-streaming** emits the *full* ``token_ids`` / ``logprobs``; these are - truncated back to the already-presented prefix, mirroring exactly how the - text channel is blanked to ``output.text[:_last_text_len]``. (Outputs - accumulate via ``list.extend`` in the hook's single-output scope, so the - truncation stays consistent across chunks.) + truncated back to the already-*emitted* prefix (the content the hook chose + to emit on prior chunks), mirroring exactly how the text channel is blanked + to ``output.text[:_last_text_len]``. Outputs accumulate via ``list.extend`` + in the hook's single-output scope, so the truncation stays consistent + across chunks: the result holds exactly the content a streaming client + would have received before this ``suppress``/``terminate`` — withholding + this chunk, not retroactively the prior ones. """ if streaming: output._last_token_ids_len = len(output.token_ids) From aa3a324eb1d4785db984f28ef894b3e6fbab1624 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 15 Jun 2026 16:26:41 -0700 Subject: [PATCH 12/16] [TRTLLM-12622][fix] Fail closed on hook errors; enum verdict actions Make the post-processing hook fail closed: a hook exception (or an invalid verdict) now re-raises so the request errors instead of serving the un-vetted chunk, matching Triton's per-request model. The in-proxy path surfaces it to the serving handler; the worker path converts it to an ErrorResponse. Server and sibling requests stay alive. Replace the free-form verdict action string with a PostProcAction enum validated in PostProcVerdict.__post_init__, so an unknown action can no longer be smuggled past the dispatch. Document the n>1 / beam behavior (emit/suppress are per output sequence; terminate cancels the whole request) instead of assuming a single output, and trim the over-verbose comments and JIRA references across the feature. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 7 + tensorrt_llm/executor/postproc_worker.py | 30 ++--- tensorrt_llm/executor/postprocessor_hook.py | 125 ++++++++---------- tensorrt_llm/executor/proxy.py | 11 +- tensorrt_llm/executor/result.py | 22 ++- tensorrt_llm/llmapi/llm.py | 13 +- tensorrt_llm/llmapi/llm_args.py | 2 +- tensorrt_llm/serve/openai_server.py | 6 +- .../executor/test_postprocessor_hook.py | 32 +++-- .../executor/test_proxy_postproc_terminate.py | 2 +- .../llmapi/apps/_postproc_hook_samples.py | 3 +- .../apps/_test_openai_post_processor.py | 4 +- tests/unittest/llmapi/test_llm_args.py | 2 +- 13 files changed, 120 insertions(+), 139 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index 78a1978d700e..47edf227605d 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -186,6 +186,13 @@ form — so keying state on `chunk.request_id` is sufficient to keep concurrent - **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that rewrites or suppresses text may desync those parsers; prefer `terminate`, or apply such hooks to plain-text requests. +- **Hook errors fail closed**: if the hook raises, the request fails with an error rather than serving + the un-vetted chunk (the server and other requests stay alive). A returned verdict with an unknown + action is rejected the same way. +- **`n` > 1 / beam search**: `emit` and `suppress` act per output sequence, but `terminate` cancels the + **whole** request (all sequences), because the engine request is the unit of cancellation — a + `terminate` on one candidate ends the others too. Hooks needing per-sequence state should key on + `(request_id, output_index)`. 
## Tests diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index ca538890535d..ece60b2374b1 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -47,10 +47,9 @@ class PostprocWorkerConfig: ''' The config for the postprocess worker. ''' num_postprocess_workers: int = 0 postprocess_tokenizer_dir: Optional[str] = None - # Dotted import path of the user post-processing hook (TRTLLM-12622), or - # None. Propagated into each postproc worker process so the detok chokepoint - # can apply it. NOTE: distinct from ``PostprocParams.post_processor`` above, - # which is the per-endpoint response *formatter* (a Callable), not this hook. + # Dotted import path of the user post-processing hook, or + # None. NOTE: distinct from ``PostprocParams.post_processor``, which is the + # per-endpoint response *formatter* (a Callable), not this hook. post_processor_hook: Optional[str] = None @property @@ -100,7 +99,7 @@ def __init__( tokenizer_dir (str): The directory to load tokenizer. record_creator (Callable[["ResponsePostprocessWorker.Input"], Any]): A creator for creating a record for a request. result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback handles the final result. - post_processor_hook (Optional[str]): Import path of the user post-processing hook (TRTLLM-12622). This worker builds and owns one instance, threaded onto each record it creates. + post_processor_hook (Optional[str]): Import path of the user post-processing hook; built once and threaded onto each record. ''' self._records: Dict[int, GenerationResult] = {} @@ -121,9 +120,8 @@ def __init__( # Load the tokenizer and share in all records self._tokenizer = load_hf_tokenizer(tokenizer_dir) - # Build the user post-processing hook once (TRTLLM-12622) and own it for - # this worker's lifetime, mirroring the tokenizer above; threaded onto - # each record in ``_handle_input``. None when unconfigured. 
+ # Build the user post-processing hook once, like the + # tokenizer above; threaded onto each record in ``_handle_input``. self._post_processor_hook = ( load_post_processor_hook(post_processor_hook) if post_processor_hook else None) @@ -159,9 +157,8 @@ async def _handle_input( # TODO: support variant creation later self._records[req_id] = self._record_creator( input, self._tokenizer) - # Thread this worker's owned hook onto the record (TRTLLM-12622) - # for the detok chokepoint. Set here, not via the record_creator - # signature, so custom record_creators keep working. + # Thread the hook onto the record here rather than + # via record_creator, so custom record_creators keep working. self._records[ req_id]._post_processor_hook = self._post_processor_hook if input.disaggregated_params is not None: @@ -222,10 +219,9 @@ async def handle_single_input(inp: PostprocWorker.Input, res, metrics, perf_metrics, disaggregated_params = await self._handle_input( inp) record = self._records.get(client_id) - # A post-processing hook that returned `terminate` forces the - # record done (TRTLLM-12622); honor it so the stream stops - # promptly and the record is popped, instead of waiting for the - # engine's own is_final. + # A `terminate` verdict forces the record done; + # honor it so the stream stops and the record is popped without + # waiting for the engine's own is_final. if record is not None and record._done: is_final = True should_abort = record._aborted if record else False @@ -297,9 +293,7 @@ def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]], tokenizer_dir: str, record_creator: Callable, post_processor_hook: Optional[str] = None): - # The worker owns its post-processing hook instance (TRTLLM-12622): pass the - # import path through to PostprocWorker, which builds it once and threads it - # onto each record for the detok chokepoint in DetokenizedGenerationResultBase. + # Pass the hook import path; PostprocWorker builds it once. 
worker = PostprocWorker(feedin_ipc_addr, feedout_ipc_addr, tokenizer_dir=tokenizer_dir, diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 0e612256b3f9..824a008dea1b 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -2,30 +2,24 @@ # SPDX-License-Identifier: Apache-2.0 """User-pluggable post-processing hook for ``trtllm-serve`` (TRTLLM-12622). -This provides a native, per-request, stateful post-processing seam equivalent -to a Triton python-backend post-processor. A user supplies a picklable, -importable callable class via the ``--post_processor`` import path; trtllm -builds one instance per owner (the ``LLM`` for the in-proxy detok path, and -each post-processing worker process when enabled) and invokes it once per -output, per streaming chunk (plus a final call), *after* detokenization and -*before* the per-endpoint response formatter. Ownership is per-instance, not a -process global: the instance is threaded onto each result alongside the -tokenizer, so independent ``LLM`` instances in one process stay isolated. - -The hook owns its own per-request state (keyed by ``chunk.request_id``) exactly -like Triton's model-managed ``self.sequences = {}`` pattern; trtllm passes only -the request id, the per-chunk payload, lifecycle flags, and the cancel signal. - -This module is intentionally dependency-light (stdlib only) so it can be loaded -in the post-processing worker process and reasoned about in isolation. +A user supplies a picklable, importable callable class via the +``--post_processor`` import path. One instance is built per owner (the ``LLM`` +for the in-proxy detok path, and each post-processing worker process when +enabled) and invoked once per output, per streaming chunk (plus a final call), +*after* detokenization and *before* the per-endpoint response formatter. The +hook owns its per-request state, keyed by ``chunk.request_id``. 
+ +Stdlib-only so it can be loaded in the post-processing worker process. """ import dataclasses +import enum import importlib import logging from typing import List, Optional, Protocol, runtime_checkable __all__ = [ + "PostProcAction", "PostProcChunk", "PostProcVerdict", "PostProcessorHook", @@ -43,21 +37,9 @@ def load_post_processor_hook(import_path: str) -> "PostProcessorHook": """Build a post-processor hook instance from a dotted import path. Mirrors ``tensorrt_llm.tokenizer.load_custom_tokenizer``: resolve - ``module.path.ClassName``, import the module, fetch the class, instantiate - it with no arguments. The class must be importable and picklable so it can - cross the post-processing worker process boundary. - - Each owner (the ``LLM`` and each postproc worker) calls this once and holds - the returned instance for the lifetime of the owner; the instance is never - pickled across a process boundary (only ``import_path`` is), and per-request - state lives inside it. - - Args: - import_path: Dotted path to the hook class, e.g. - ``'my_pkg.guardrail.MyPostProcessor'``. - - Returns: - An instance of the hook class. + ``module.path.ClassName``, import the module, instantiate it with no + arguments. Only ``import_path`` crosses a process boundary (never the + instance), so the class must be importable and picklable. Raises: ValueError: If the path cannot be resolved, imported, or instantiated. @@ -106,6 +88,14 @@ class PostProcChunk: streaming: bool +class PostProcAction(str, enum.Enum): + """The kind of decision a hook returns for one chunk.""" + + EMIT = "emit" + SUPPRESS = "suppress" + TERMINATE = "terminate" + + @dataclasses.dataclass class PostProcVerdict: """The hook's decision for one chunk. @@ -114,24 +104,28 @@ class PostProcVerdict: than constructing this directly. 
""" - action: str # "emit" | "suppress" | "terminate" + action: PostProcAction text: str = "" reason: Optional[str] = None + def __post_init__(self): + # Coerce/validate so a hook can never smuggle an unknown action. + self.action = PostProcAction(self.action) + def emit(text: str) -> PostProcVerdict: """Emit ``text`` for this chunk (use to rewrite/redact, or pass through).""" - return PostProcVerdict(action="emit", text=text) + return PostProcVerdict(action=PostProcAction.EMIT, text=text) def suppress() -> PostProcVerdict: """Withhold this chunk entirely (no client-visible output).""" - return PostProcVerdict(action="suppress") + return PostProcVerdict(action=PostProcAction.SUPPRESS) def terminate(reason: str) -> PostProcVerdict: """Stop the stream for this request. ``reason`` is surfaced as stop_reason.""" - return PostProcVerdict(action="terminate", reason=reason) + return PostProcVerdict(action=PostProcAction.TERMINATE, reason=reason) @runtime_checkable @@ -149,24 +143,11 @@ def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: ... def _withhold_token_channel(output, streaming: bool) -> None: """Withhold the raw token-id / logprob channels alongside the blanked text. - ``suppress``/``terminate`` blank the text channel; the raw token-id and - logprob channels must be withheld too, or a suppressed/terminated output - leaks via them — ``token_ids`` on ``/v1/completions`` with - ``detokenize=False``, and ``logprobs`` on both chat and completions. - - The two response shapes withhold differently, matching what each formatter - emits (verified by the channel audit): - - * **streaming** emits per-chunk *diffs* (``token_ids_diff`` / - ``logprobs_diff``); advancing the diff watermark empties this chunk. 
- * **non-streaming** emits the *full* ``token_ids`` / ``logprobs``; these are - truncated back to the already-*emitted* prefix (the content the hook chose - to emit on prior chunks), mirroring exactly how the text channel is blanked - to ``output.text[:_last_text_len]``. Outputs accumulate via ``list.extend`` - in the hook's single-output scope, so the truncation stays consistent - across chunks: the result holds exactly the content a streaming client - would have received before this ``suppress``/``terminate`` — withholding - this chunk, not retroactively the prior ones. + Otherwise a suppressed/terminated output leaks via them (``token_ids`` on + ``/v1/completions`` with ``detokenize=False``, ``logprobs`` on both + endpoints). Streaming emits per-chunk diffs, so advancing the diff watermark + empties this chunk; non-streaming emits the full lists, so truncate them back + to the already-emitted prefix, mirroring how the text is blanked. """ if streaming: output._last_token_ids_len = len(output.token_ids) @@ -185,14 +166,16 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) (preserving the already-emitted prefix), suppressing it, or terminating the stream via the existing abort machinery. - Hook exceptions are isolated per request: they are logged and the chunk is - passed through unchanged (fail-open), so a buggy hook cannot wedge the - worker or crash the serving loop. This is consistent across the in-proxy and - postproc-worker paths. + A hook exception fails the request closed (re-raised), never serving the + un-vetted chunk: the in-proxy path surfaces it to the serving handler, the + worker path converts it to an ``ErrorResponse``. Both keep the server and + other requests alive (mirrors Triton's per-request fail-closed model). 
""" - # ``is_final`` is request-level (``result._done``), not per-output; under the - # locked 1:1 single-output scope (TRTLLM-12622) this is the exact cleanup - # signal for hooks that release per-request state on ``is_final``. + # ``is_final`` is request-level (``result._done``): for n>1 / beam it fires + # once when the whole request finishes. emit/suppress act per output, but a + # terminate cancels the whole request (all outputs) because the engine + # request is the unit of cancellation. Hooks needing per-sequence state + # should key on (request_id, output_index). is_final = result._done for output in result.outputs: chunk = PostProcChunk( @@ -209,28 +192,26 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) verdict = hook(chunk) except Exception: logger.exception( - "Post-processor hook raised for request %s; passing the chunk through unchanged.", + "Post-processor hook failed for request %s; failing the request closed.", result.id, ) - continue + raise prefix = output.text[: output._last_text_len] - if verdict.action == "emit": + if verdict.action is PostProcAction.EMIT: output.text = prefix + verdict.text - elif verdict.action == "suppress": + elif verdict.action is PostProcAction.SUPPRESS: output.text = prefix _withhold_token_channel(output, streaming) - elif verdict.action == "terminate": + elif verdict.action is PostProcAction.TERMINATE: output.text = prefix + verdict.text _withhold_token_channel(output, streaming) output.finish_reason = "stop" output.stop_reason = verdict.reason result._aborted = True result._done = True - # Cancel the engine request as well. On the in-proxy path this stops - # wasted generation; on the worker path the record's abort() only - # sets the flag and the engine is cancelled by the proxy via - # should_abort. The getattr guard is defensive (real results always - # define abort()). 
+ # Cancel the engine request to stop wasted generation (on the worker + # path the proxy does the actual cancel via should_abort). getattr + # guard is defensive; real results always define abort(). abort = getattr(result, "abort", None) if callable(abort): try: @@ -240,4 +221,6 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) "Failed to abort request %s after terminate verdict.", result.id ) else: - raise ValueError(f"Unknown post-processor verdict action: {verdict.action!r}") + # Unreachable for hook-returned verdicts (validated in + # ``__post_init__``); guards an unhandled future enum member. + raise ValueError(f"Unhandled post-processor action: {verdict.action!r}") diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index c9f14555a3bc..ffe793bc5b48 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -327,13 +327,10 @@ def process_res(res): nonlocal event_loop nonlocal async_queues - # The result may have already been finalized and removed — e.g. on - # the worker path a post-processor hook (TRTLLM-12622) returned - # `terminate`, so the worker sent a final Output that popped this - # client_id below, while the engine still has in-flight responses for - # it (abort is async, and the postproc worker recreates a record for - # any late response). Drop such late responses instead of crashing - # with a KeyError. + # The result may already be finalized and popped below — e.g. a + # `terminate` verdict sent a final Output while the + # engine still has in-flight responses (abort is async). Drop such + # late responses instead of crashing with a KeyError. 
result = self._results.get(client_id) if result is None: return diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 0219fb1c734b..1d38bbe931b6 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -828,9 +828,8 @@ def __init__(self, ) self.tokenizer = tokenizer self._streaming = streaming - # User post-processing hook (TRTLLM-12622) owned by this result's - # creator (the LLM for the in-proxy path, the worker for the worker - # path) and threaded in alongside the tokenizer; None when unconfigured. + # User post-processing hook, threaded in alongside the + # tokenizer by this result's creator; None when unconfigured. self._post_processor_hook = post_processor_hook def _handle_response(self, response: "GenerationExecutor.Response"): @@ -847,11 +846,9 @@ def _handle_response(self, response: "GenerationExecutor.Response"): self.sampling_params.spaces_between_special_tokens } # Detokenize when the client asked for text, OR whenever a post-processing - # hook (TRTLLM-12622) is configured: the hook is a server-side guardrail - # and must run on every response regardless of the client's - # ``detokenize`` flag (which only controls the returned channel, honored - # separately by the response formatter). Without this, a client could - # bypass the guardrail with ``detokenize=False``. + # hook is configured: the hook runs on every response regardless of the + # client's ``detokenize`` flag, else it could be bypassed with + # ``detokenize=False``. if (self.sampling_params.detokenize or self._post_processor_hook is not None) and self.tokenizer is not None: for beam_output in self.outputs: @@ -896,13 +893,12 @@ def _handle_response(self, response: "GenerationExecutor.Response"): self._apply_post_processor_hook() def _apply_post_processor_hook(self): - """Run the user post-processing hook (TRTLLM-12622) at the detok chokepoint. + """Run the user post-processing hook at the detok chokepoint. 
Runs after detok populated ``text``/``text_diff`` and before any - per-endpoint formatter reads them. Shared by the postproc-worker path - and the in-proxy path; the hook instance is owned by this result's - creator and threaded in via ``post_processor_hook`` (``None`` when - unconfigured), so independent ``LLM`` instances stay isolated. + per-endpoint formatter reads them. The hook instance is threaded in via + ``post_processor_hook`` (``None`` when unconfigured) so independent + ``LLM`` instances stay isolated. """ hook = self._post_processor_hook if hook is None: diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index f76b7d8a1de1..2e6c80273734 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -89,9 +89,8 @@ def _from_generation_result( inst.tokenizer = tokenizer inst._streaming = generation_result._streaming inst._prompt = prompt - # User post-processing hook (TRTLLM-12622) owned by the LLM instance; - # threaded onto the result the user holds (the object the in-proxy detok - # runs on) alongside the tokenizer. None when unconfigured. + # User post-processing hook; threaded onto the result the + # user holds, where the in-proxy detok runs. None when unconfigured. inst._post_processor_hook = post_processor_hook return inst @@ -238,11 +237,9 @@ def __init__(self, "yellow") self.mpi_session = self.args.mpi_session - # Build this LLM's post-processing hook instance (TRTLLM-12622), if - # configured. Ownership is per-instance: this instance is threaded onto - # the results this LLM produces (in-proxy detok path), and each postproc - # worker builds its own from the same import path. Resolving here also - # fails fast on a bad import path at startup rather than per-request. + # Build this LLM's post-processing hook for the in-proxy detok path (each + # postproc worker builds its own). Resolving here fails fast on a bad + # import path at startup rather than per-request. 
_post_processor_path = getattr(self.args, "post_processor", None) self._post_processor_hook = ( load_post_processor_hook(_post_processor_path) diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index a492a7716020..c5a0f6518651 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3535,7 +3535,7 @@ def validate_parallel_config(self): def validate_and_init_tokenizer(self): """Initialize tokenizer based on configuration.""" if self.skip_tokenizer_init: - # The post-processing hook (TRTLLM-12622) is a text-based guardrail + # The post-processing hook is a text-based guardrail # and needs detokenized text to inspect; without a tokenizer it could # never run, so reject the combination rather than silently disabling # the guardrail (mirrors the harmony fail-fast in OpenAIServer). diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 2b438d0daae1..cbeea97f79f0 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -456,10 +456,8 @@ def _ensure_post_processor_supported(use_harmony: bool, post_processor: Optional[str]) -> None: """Reject ``--post_processor`` combined with a harmony/gpt-oss model. - The harmony output path rebuilds the client-visible output from raw - output token ids rather than the detokenized text, so the text-based - post-processing hook (TRTLLM-12622) cannot act there. Fail fast at - startup rather than silently bypassing a guardrail. + The harmony output path is rebuilt from raw token ids, not detokenized + text, so the text-based hook cannot act there. 
""" if use_harmony and post_processor: raise ValueError( diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 385eee861060..e548982d3425 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Unit tests for the trtllm-serve post-processing hook (TRTLLM-12622).""" +"""Unit tests for the trtllm-serve post-processing hook.""" import pytest @@ -33,7 +33,7 @@ def __init__( self._done = done self._aborted = aborted self._streaming = streaming - # Per-instance hook ownership (TRTLLM-12622): the detok read site reads + # Per-instance hook ownership: the detok read site reads # this attribute rather than a process global. self._post_processor_hook = post_processor_hook self.abort_called = 0 @@ -111,7 +111,7 @@ def test_suppress_blanks_token_and_logprob_diffs(): def test_suppress_blanks_full_token_and_logprob_channels_non_streaming(): - """Non-streaming suppress must blank the full token/logprob channels (TRTLLM-12622). + """Non-streaming suppress must blank the full token/logprob channels. Non-streaming emits the full token_ids/logprobs (not the diff), so suppress must truncate the full channels — otherwise a detokenize=False completion @@ -155,18 +155,19 @@ def test_terminate_without_abort_attr_does_not_crash(): assert result._done is True -def test_hook_exception_fails_open_passthrough(): +def test_hook_exception_fails_closed_and_reraises(): out = _make_output("hello world", last_text_len=len("hello")) result = _FakeResult([out]) def boom(chunk): raise RuntimeError("hook bug") - # Must not propagate; the chunk passes through unchanged (fail-open). 
- apply_post_processor_hook(boom, result, streaming=True) + # Fail-closed: the exception propagates so the request errors rather than + # serving the un-vetted chunk. The hook path leaves ``out.text`` unmodified. + with pytest.raises(RuntimeError, match="hook bug"): + apply_post_processor_hook(boom, result, streaming=True) assert out.text == "hello world" - assert out.text_diff == " world" def test_non_streaming_rewrites_full_text(): @@ -224,12 +225,21 @@ def __call__(self, chunk: PostProcChunk): assert 1 not in hook.state -def test_unknown_verdict_action_raises(): +def test_unknown_verdict_action_rejected_at_construction(): + """An unknown action cannot be smuggled: the verdict rejects it on build.""" + from tensorrt_llm.executor.postprocessor_hook import PostProcVerdict + + with pytest.raises(ValueError): + PostProcVerdict(action="bogus") + + +def test_unknown_action_fails_closed_through_apply(): + """A hook constructing a bad verdict fails the request closed (re-raised).""" from tensorrt_llm.executor.postprocessor_hook import PostProcVerdict out = _make_output("x", 0) result = _FakeResult([out]) - with pytest.raises(ValueError, match="Unknown post-processor verdict"): + with pytest.raises(ValueError): apply_post_processor_hook(lambda c: PostProcVerdict(action="bogus"), result, streaming=True) @@ -247,7 +257,7 @@ def test_loader_raises_on_bad_path(): def test_loader_builds_independent_instances(): """Each owner builds its own instance (no shared process-global cache). - This is the core of per-instance ownership (TRTLLM-12622): two owners + This is the core of per-instance ownership: two owners loading the same import path get distinct instances, so their per-request state never collides. """ @@ -282,7 +292,7 @@ def test_apply_method_is_noop_when_instance_has_no_hook(): def test_harmony_model_rejects_post_processor(): - """A harmony/gpt-oss model + post_processor must fail fast (TRTLLM-12622). + """A harmony/gpt-oss model + post_processor must fail fast. 
The harmony output path is rebuilt from raw token ids and would bypass the text-based hook, so the server refuses the combination at startup. diff --git a/tests/unittest/executor/test_proxy_postproc_terminate.py b/tests/unittest/executor/test_proxy_postproc_terminate.py index 6f03526e5f1f..00b6883c8734 100644 --- a/tests/unittest/executor/test_proxy_postproc_terminate.py +++ b/tests/unittest/executor/test_proxy_postproc_terminate.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """Regression tests for proxy late-response handling after a hook terminate. -When a post-processor hook (TRTLLM-12622) returns ``terminate`` the result is +When a post-processor hook returns ``terminate`` the result is marked done and popped from the proxy's ``_results`` map, but the engine may still have in-flight responses for the same ``client_id`` (abort is async). ``GenerationExecutorProxy.dispatch_result_task`` diff --git a/tests/unittest/llmapi/apps/_postproc_hook_samples.py b/tests/unittest/llmapi/apps/_postproc_hook_samples.py index 24f978a4d7e8..077630ebbd25 100644 --- a/tests/unittest/llmapi/apps/_postproc_hook_samples.py +++ b/tests/unittest/llmapi/apps/_postproc_hook_samples.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""Sample post-processing hooks for the trtllm-serve hook integration test -(TRTLLM-12622). +"""Sample post-processing hooks for the trtllm-serve hook integration test. These are deliberately stateless and deterministic so the test can assert the client-visible effect regardless of the (non-deterministic) model output. 
Each diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py index 85139dcc2cb1..bb2571d75ab7 100644 --- a/tests/unittest/llmapi/apps/_test_openai_post_processor.py +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -1,6 +1,6 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 -"""End-to-end tests for the trtllm-serve post-processing hook (TRTLLM-12622). +"""End-to-end tests for the trtllm-serve post-processing hook. Launches a real ``trtllm-serve`` with ``--post_processor`` pointing at one of the sample hooks in ``_postproc_hook_samples`` and asserts the client-visible @@ -110,7 +110,7 @@ def test_completions_detokenize_false_does_not_bypass_hook( ): """A server-side hook is a guardrail and must run even when the client sets ``detokenize=false`` — that flag controls only the returned channel, not - whether the hook executes (TRTLLM-12622).""" + whether the hook executes.""" completion = client.completions.create( model=model_name, prompt="Hello, my name is", diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 8fe8f6dd5e71..cb4d3bea0d10 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -225,7 +225,7 @@ def test_decoding_type_eagle3_errors_on_tensorrt_backend(): def test_post_processor_rejected_with_skip_tokenizer_init(): - """post_processor + skip_tokenizer_init must fail fast (TRTLLM-12622). + """post_processor + skip_tokenizer_init must fail fast. The hook is a text-based guardrail; pairing it with skip_tokenizer_init (no detokenized text) must be rejected rather than silently disabling it. 
From f66eb4d09c380047a915cc21cae63f01358d8b3b Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 16 Jun 2026 09:39:33 -0700 Subject: [PATCH 13/16] [TRTLLM-12622][test] Register post-processor hook unit tests in L0 A10 The dedicated unit test files were not enumerated in any test-db list, so CI never collected them. Add them to the executor unittest block in l0_a10.yml so they run in the A10-PyTorch stage. Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tests/integration/test_lists/test-db/l0_a10.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 64d39025d03a..fd7cc5ae2ffd 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -131,6 +131,8 @@ l0_a10: - unittest/executor/test_rpc.py - unittest/executor/test_ipc.py - unittest/executor/test_fatal_error_health_check.py + - unittest/executor/test_postprocessor_hook.py + - unittest/executor/test_proxy_postproc_terminate.py # trtllm-serve CPU-only - unittest/llmapi/apps/test_chat_utils.py - unittest/llmapi/apps/test_tool_parsers.py From 689c8dbd63b1609186214a7afad27a3e1c264cbf Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Mon, 22 Jun 2026 15:09:08 -0700 Subject: [PATCH 14/16] [TRTLLM-12622][fix] Validate hook is callable at load time load_post_processor_hook now checks callable() on the instantiated hook and raises a descriptive ValueError, so a non-callable import path fails at load/startup instead of per-chunk at runtime (benefits CI). Addresses the CodeRabbit suggestion acknowledged on PR #15239 (thread r3425180772). Updates the loader unit tests to use a callable stand-in (MagicMock) and adds a non-callable negative case (OrderedDict). 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- tensorrt_llm/executor/postprocessor_hook.py | 9 ++++++++- .../executor/test_postprocessor_hook.py | 18 ++++++++++++++---- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 824a008dea1b..61e20cafd69c 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -48,13 +48,20 @@ def load_post_processor_hook(import_path: str) -> "PostProcessorHook": module_path, class_name = import_path.rsplit(".", 1) module = importlib.import_module(module_path) hook_class = getattr(module, class_name) - return hook_class() + hook = hook_class() except (ValueError, ImportError, AttributeError, TypeError) as e: raise ValueError( f"Failed to load post-processor hook '{import_path}': {e}. " "Expected format: 'module.path.ClassName' resolving to a " "no-arg-constructible callable class." ) from e + if not callable(hook): + raise ValueError( + f"Failed to load post-processor hook '{import_path}': resolved " + f"object is not callable (got {type(hook).__name__}); expected an " + "instance implementing __call__(chunk)." + ) + return hook @dataclasses.dataclass diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index e548982d3425..6cc814489c29 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -244,8 +244,9 @@ def test_unknown_action_fails_closed_through_apply(): def test_loader_resolves_import_path(): - # Any importable, no-arg-constructible class works as a smoke test. - hook = load_post_processor_hook("collections.OrderedDict") + # Any importable, no-arg-constructible callable instance works as a smoke + # test. MagicMock instances are callable, unlike e.g. OrderedDict. 
+ hook = load_post_processor_hook("unittest.mock.MagicMock") assert hook is not None @@ -254,6 +255,15 @@ def test_loader_raises_on_bad_path(): load_post_processor_hook("no.such.module.Nope") +def test_loader_raises_on_non_callable(): + """A non-callable resolved instance fails at load time, not per-chunk. + + OrderedDict instances are not callable, so they exercise this path. + """ + with pytest.raises(ValueError, match="not callable"): + load_post_processor_hook("collections.OrderedDict") + + def test_loader_builds_independent_instances(): """Each owner builds its own instance (no shared process-global cache). @@ -261,8 +271,8 @@ def test_loader_builds_independent_instances(): loading the same import path get distinct instances, so their per-request state never collides. """ - a = load_post_processor_hook("collections.OrderedDict") - b = load_post_processor_hook("collections.OrderedDict") + a = load_post_processor_hook("unittest.mock.MagicMock") + b = load_post_processor_hook("unittest.mock.MagicMock") assert a is not b From 93793de60d68cf3f0eb9871d3675e071cd746944 Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 23 Jun 2026 09:46:37 -0700 Subject: [PATCH 15/16] [TRTLLM-12622][chore] Rename post-processor hook public types, de-duplicate docs Address PR review feedback: - Rename PostProcChunk/PostProcVerdict/PostProcAction to PostProcessorHookChunk/PostProcessorHookVerdict/PostProcessorHookAction so they read distinctly from the existing postproc-worker types (PostprocParams, PostprocWorker, PostprocWorkerConfig). - Replace the duplicated PostProcessorHookChunk field table in the docs with a pointer to the authoritative dataclass to avoid drift. - Use repo-relative paths for the source-file links instead of absolute github.com blob URLs. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 38 ++++++++----------- tensorrt_llm/executor/postprocessor_hook.py | 38 +++++++++---------- .../executor/test_postprocessor_hook.py | 18 +++++---- .../llmapi/apps/_postproc_hook_samples.py | 10 ++--- 4 files changed, 49 insertions(+), 55 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index 47edf227605d..0eb1d1b0e27f 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -17,7 +17,7 @@ This feature is a prototype and its interface may change in a future release. ``` For the interface definitions referenced below, see -[`tensorrt_llm/executor/postprocessor_hook.py`](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/executor/postprocessor_hook.py). +[`tensorrt_llm/executor/postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py). ## Enabling the hook @@ -49,8 +49,8 @@ A hook implements a single method, `__call__(self, chunk) -> verdict`: ```python from tensorrt_llm.executor.postprocessor_hook import ( - PostProcChunk, - PostProcVerdict, + PostProcessorHookChunk, + PostProcessorHookVerdict, emit, suppress, terminate, @@ -58,24 +58,16 @@ from tensorrt_llm.executor.postprocessor_hook import ( class MyPostProcessor: - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return emit(chunk.text_diff) # pass through unchanged ``` -### `PostProcChunk` +### `PostProcessorHookChunk` -The payload handed to the hook for one output chunk: - -| Field | Description | -|-------|-------------| -| `request_id` | Stable identifier for the request; the same value is passed for every chunk of a response, so the hook can key its per-request state on it. | -| `output_index` | Index of the output/beam within the request. 
| -| `text_diff` | Newly detokenized text produced by this chunk (streaming). For non-streaming requests this equals `text`. | -| `text` | Full accumulated detokenized text so far for this output. | -| `token_ids_diff` | Newly generated token ids for this chunk. | -| `is_final` | `True` on the terminating call for this output. | -| `aborted` | `True` if the request has been marked aborted in this process. Output-side observation only. | -| `streaming` | `True` for streaming requests. | +The payload handed to the hook for one output chunk — `request_id`, `output_index`, `text_diff`, +`text`, `token_ids_diff`, `is_final`, `aborted`, and `streaming`. See the `PostProcessorHookChunk` +dataclass in [`postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py) +for the authoritative field-by-field descriptions. ### Verdicts @@ -96,11 +88,11 @@ Verdicts are **per chunk**: `suppress()` withholds the current chunk, and `termi A stateless hook that upper-cases every chunk: ```python -from tensorrt_llm.executor.postprocessor_hook import PostProcChunk, PostProcVerdict, emit +from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, emit class UpperCaseHook: - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return emit(chunk.text_diff.upper()) ``` @@ -111,7 +103,7 @@ phrase appears, and releases its state when the request finishes: ```python from tensorrt_llm.executor.postprocessor_hook import ( - PostProcChunk, PostProcVerdict, emit, terminate, + PostProcessorHookChunk, PostProcessorHookVerdict, emit, terminate, ) @@ -122,7 +114,7 @@ class BannedPhraseGuard: # Per-request accumulators owned entirely by the hook. 
self._buffers: dict[int, str] = {} - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: buffer = self._buffers.get(chunk.request_id, "") + chunk.text_diff.lower() self._buffers[chunk.request_id] = buffer @@ -140,11 +132,11 @@ class BannedPhraseGuard: A hook that withholds all client-visible text: ```python -from tensorrt_llm.executor.postprocessor_hook import PostProcChunk, PostProcVerdict, suppress +from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, suppress class SuppressHook: - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return suppress() ``` diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index 61e20cafd69c..ba02fbd78db6 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -19,9 +19,9 @@ from typing import List, Optional, Protocol, runtime_checkable __all__ = [ - "PostProcAction", - "PostProcChunk", - "PostProcVerdict", + "PostProcessorHookAction", + "PostProcessorHookChunk", + "PostProcessorHookVerdict", "PostProcessorHook", "emit", "suppress", @@ -65,7 +65,7 @@ def load_post_processor_hook(import_path: str) -> "PostProcessorHook": @dataclasses.dataclass -class PostProcChunk: +class PostProcessorHookChunk: """The payload handed to the post-processing hook for one output chunk. Attributes: @@ -95,7 +95,7 @@ class PostProcChunk: streaming: bool -class PostProcAction(str, enum.Enum): +class PostProcessorHookAction(str, enum.Enum): """The kind of decision a hook returns for one chunk.""" EMIT = "emit" @@ -104,35 +104,35 @@ class PostProcAction(str, enum.Enum): @dataclasses.dataclass -class PostProcVerdict: +class PostProcessorHookVerdict: """The hook's decision for one chunk. 
Use the :func:`emit`, :func:`suppress`, and :func:`terminate` helpers rather than constructing this directly. """ - action: PostProcAction + action: PostProcessorHookAction text: str = "" reason: Optional[str] = None def __post_init__(self): # Coerce/validate so a hook can never smuggle an unknown action. - self.action = PostProcAction(self.action) + self.action = PostProcessorHookAction(self.action) -def emit(text: str) -> PostProcVerdict: +def emit(text: str) -> PostProcessorHookVerdict: """Emit ``text`` for this chunk (use to rewrite/redact, or pass through).""" - return PostProcVerdict(action=PostProcAction.EMIT, text=text) + return PostProcessorHookVerdict(action=PostProcessorHookAction.EMIT, text=text) -def suppress() -> PostProcVerdict: +def suppress() -> PostProcessorHookVerdict: """Withhold this chunk entirely (no client-visible output).""" - return PostProcVerdict(action=PostProcAction.SUPPRESS) + return PostProcessorHookVerdict(action=PostProcessorHookAction.SUPPRESS) -def terminate(reason: str) -> PostProcVerdict: +def terminate(reason: str) -> PostProcessorHookVerdict: """Stop the stream for this request. ``reason`` is surfaced as stop_reason.""" - return PostProcVerdict(action=PostProcAction.TERMINATE, reason=reason) + return PostProcessorHookVerdict(action=PostProcessorHookAction.TERMINATE, reason=reason) @runtime_checkable @@ -144,7 +144,7 @@ class PostProcessorHook(Protocol): and is responsible for releasing it on ``chunk.is_final``. """ - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: ... + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: ... def _withhold_token_channel(output, streaming: bool) -> None: @@ -185,7 +185,7 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) # should key on (request_id, output_index). 
is_final = result._done for output in result.outputs: - chunk = PostProcChunk( + chunk = PostProcessorHookChunk( request_id=result.id, output_index=output.index, text_diff=output.text_diff, @@ -204,12 +204,12 @@ def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) ) raise prefix = output.text[: output._last_text_len] - if verdict.action is PostProcAction.EMIT: + if verdict.action is PostProcessorHookAction.EMIT: output.text = prefix + verdict.text - elif verdict.action is PostProcAction.SUPPRESS: + elif verdict.action is PostProcessorHookAction.SUPPRESS: output.text = prefix _withhold_token_channel(output, streaming) - elif verdict.action is PostProcAction.TERMINATE: + elif verdict.action is PostProcessorHookAction.TERMINATE: output.text = prefix + verdict.text _withhold_token_channel(output, streaming) output.finish_reason = "stop" diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 6cc814489c29..528c3514f716 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -5,7 +5,7 @@ import pytest from tensorrt_llm.executor.postprocessor_hook import ( - PostProcChunk, + PostProcessorHookChunk, apply_post_processor_hook, emit, load_post_processor_hook, @@ -57,7 +57,7 @@ def test_rewrite_streaming_diff_replaces_only_the_diff(): out = _make_output("hello world", last_text_len=len("hello")) result = _FakeResult([out]) - def hook(chunk: PostProcChunk): + def hook(chunk: PostProcessorHookChunk): assert chunk.text_diff == " world" assert chunk.text == "hello world" return emit(chunk.text_diff.upper()) @@ -175,7 +175,7 @@ def test_non_streaming_rewrites_full_text(): out = _make_output("the full answer", last_text_len=0) result = _FakeResult([out], done=True) - def hook(chunk: PostProcChunk): + def hook(chunk: PostProcessorHookChunk): assert chunk.text_diff == chunk.text == "the full answer" return emit("REDACTED") 
@@ -201,7 +201,7 @@ class Counter: def __init__(self): self.state = {} - def __call__(self, chunk: PostProcChunk): + def __call__(self, chunk: PostProcessorHookChunk): n = self.state.get(chunk.request_id, 0) + 1 self.state[chunk.request_id] = n if chunk.is_final: @@ -227,20 +227,22 @@ def __call__(self, chunk: PostProcChunk): def test_unknown_verdict_action_rejected_at_construction(): """An unknown action cannot be smuggled: the verdict rejects it on build.""" - from tensorrt_llm.executor.postprocessor_hook import PostProcVerdict + from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookVerdict with pytest.raises(ValueError): - PostProcVerdict(action="bogus") + PostProcessorHookVerdict(action="bogus") def test_unknown_action_fails_closed_through_apply(): """A hook constructing a bad verdict fails the request closed (re-raised).""" - from tensorrt_llm.executor.postprocessor_hook import PostProcVerdict + from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookVerdict out = _make_output("x", 0) result = _FakeResult([out]) with pytest.raises(ValueError): - apply_post_processor_hook(lambda c: PostProcVerdict(action="bogus"), result, streaming=True) + apply_post_processor_hook( + lambda c: PostProcessorHookVerdict(action="bogus"), result, streaming=True + ) def test_loader_resolves_import_path(): diff --git a/tests/unittest/llmapi/apps/_postproc_hook_samples.py b/tests/unittest/llmapi/apps/_postproc_hook_samples.py index 077630ebbd25..7ae93eb7be45 100644 --- a/tests/unittest/llmapi/apps/_postproc_hook_samples.py +++ b/tests/unittest/llmapi/apps/_postproc_hook_samples.py @@ -10,8 +10,8 @@ class is a top-level, no-arg-constructible, importable callable so it can be """ from tensorrt_llm.executor.postprocessor_hook import ( - PostProcChunk, - PostProcVerdict, + PostProcessorHookChunk, + PostProcessorHookVerdict, emit, suppress, terminate, @@ -21,19 +21,19 @@ class is a top-level, no-arg-constructible, importable callable so it can be class 
UppercaseHook: """Rewrite every chunk's text to upper case.""" - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return emit(chunk.text_diff.upper()) class SuppressHook: """Withhold all output (every chunk is suppressed).""" - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return suppress() class TerminateHook: """Terminate the stream immediately on the first chunk seen.""" - def __call__(self, chunk: PostProcChunk) -> PostProcVerdict: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return terminate("test_policy") From 5fd660e82f5d067bdfc8c94c35a5282ef834c1af Mon Sep 17 00:00:00 2001 From: Xiao Wang <24860335+xwang233@users.noreply.github.com> Date: Tue, 23 Jun 2026 10:06:35 -0700 Subject: [PATCH 16/16] [TRTLLM-12622][chore] Rename --post_processor flag/field to --post_processor_hook Address PR review feedback (one level up from the type rename): the user-facing CLI flag --post_processor and the BaseLlmArgs field post_processor collided head-on with the pre-existing PostprocParams.post_processor (the internal per-endpoint response formatter), in the same postproc-worker subsystem. Rename the user-facing surface to post_processor_hook, matching the PostProcessorHook* types, the PostProcessorHook Protocol, the internal PostprocWorkerConfig.post_processor_hook, and the doc file name. The existing formatter PostprocParams.post_processor is left untouched. Touches: CLI flag + plumbing (serve.py), BaseLlmArgs field + validation (llm_args.py), eager-load (llm.py), harmony guard rename (openai_server._ensure_post_processor_hook_supported), api_stability reference, unit/e2e tests, and docs. 
Signed-off-by: Xiao Wang <24860335+xwang233@users.noreply.github.com> --- docs/source/features/post-processor-hook.md | 12 ++++---- tensorrt_llm/commands/serve.py | 30 ++++++++++--------- tensorrt_llm/executor/postprocessor_hook.py | 2 +- tensorrt_llm/llmapi/llm.py | 6 ++-- tensorrt_llm/llmapi/llm_args.py | 8 ++--- tensorrt_llm/serve/openai_server.py | 14 ++++----- .../api_stability/references/llm.yaml | 2 +- .../executor/test_postprocessor_hook.py | 14 ++++----- .../llmapi/apps/_postproc_hook_samples.py | 2 +- .../apps/_test_openai_post_processor.py | 4 +-- tests/unittest/llmapi/test_llm_args.py | 6 ++-- 11 files changed, 51 insertions(+), 49 deletions(-) diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md index 0eb1d1b0e27f..3046dc4c8045 100644 --- a/docs/source/features/post-processor-hook.md +++ b/docs/source/features/post-processor-hook.md @@ -21,16 +21,16 @@ For the interface definitions referenced below, see ## Enabling the hook -Pass the dotted import path of your hook class to `--post_processor`: +Pass the dotted import path of your hook class to `--post_processor_hook`: ```bash -trtllm-serve --post_processor my_pkg.guardrail.MyPostProcessor +trtllm-serve --post_processor_hook my_pkg.guardrail.MyPostProcessorHook ``` Equivalently, set it in a YAML config passed via `--extra_llm_api_options`: ```yaml -post_processor: my_pkg.guardrail.MyPostProcessor +post_processor_hook: my_pkg.guardrail.MyPostProcessorHook ``` The class must be: @@ -57,7 +57,7 @@ from tensorrt_llm.executor.postprocessor_hook import ( ) -class MyPostProcessor: +class MyPostProcessorHook: def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: return emit(chunk.text_diff) # pass through unchanged ``` @@ -164,11 +164,11 @@ form — so keying state on `chunk.request_id` is sufficient to keep concurrent `suppress`/`terminate` verdict withholds **all** client-visible channels — text, `token_ids`, and `logprobs` — on both the 
streaming and non-streaming paths, so a client cannot recover withheld content through any channel. -- **Requires a tokenizer**: the hook needs detokenized text to inspect, so `--post_processor` combined +- **Requires a tokenizer**: the hook needs detokenized text to inspect, so `--post_processor_hook` combined with `skip_tokenizer_init` is rejected at startup rather than silently disabled. - **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from raw token ids, it would bypass the text-based hook, so the server fails fast at startup when - `--post_processor` is combined with a harmony model. + `--post_processor_hook` is combined with a harmony model. - **Disaggregated serving**: the context and generation servers are separate processes, each running the hook on its own phase under a different `request_id`; per-request state cannot be correlated across the two. A `terminate` on one phase does not propagate to the other. diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index 1f4847df7bd5..c4686aa6477c 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -131,8 +131,8 @@ def is_non_default_or_required(param_name, value, backend, explicit_cli_keys): 3. Different from its default value in the backend's LlmArgs class """ always_include = { - "model", "backend", "tokenizer", "custom_tokenizer", "post_processor", - "postprocess_tokenizer_dir" + "model", "backend", "tokenizer", "custom_tokenizer", + "post_processor_hook", "postprocess_tokenizer_dir" } if param_name in always_include: @@ -181,7 +181,7 @@ def get_llm_args( model: str, tokenizer: Optional[str] = None, custom_tokenizer: Optional[str] = None, - post_processor: Optional[str] = None, + post_processor_hook: Optional[str] = None, backend: str = "pytorch", max_beam_width: int = BuildConfig.model_fields["max_beam_width"]. 
default, @@ -239,8 +239,8 @@ def get_llm_args( tokenizer, "custom_tokenizer": custom_tokenizer, - "post_processor": - post_processor, + "post_processor_hook": + post_processor_hook, "postprocess_tokenizer_dir": tokenizer or model, "kv_cache_config": @@ -638,16 +638,18 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').", "prototype")) @click.option( - "--post_processor", + "--post_processor_hook", type=str, default=None, help=help_info_with_stability_tag( "Python import path of a user post-processing hook applied after " - "detokenization and before the response formatter (e.g. " - "'my_pkg.guardrail.MyPostProcessor'). The class must be importable, " - "picklable, take no constructor arguments, and be callable per chunk; " - "it may rewrite, suppress, or terminate the output and owns its own " - "per-request state.", "prototype")) + "detokenization and before the per-endpoint response formatter (e.g. " + "'my_pkg.guardrail.MyPostProcessorHook'). The class must be importable " + "and picklable, take no constructor arguments, and be callable as " + "'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). " + "It runs once per output, per streaming chunk, and may rewrite, " + "suppress, or terminate the output; it owns its own per-request state.", + "prototype")) @click.option("--host", type=str, default="localhost", @@ -926,8 +928,8 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "Types of agents to schedule. 
Now Only Support Open Deep Research agent.") def serve( model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], - post_processor: Optional[str], host: str, port: int, log_level: str, - backend: str, max_beam_width: int, max_batch_size: int, + post_processor_hook: Optional[str], host: str, port: int, + log_level: str, backend: str, max_beam_width: int, max_batch_size: int, max_num_tokens: int, max_seq_len: int, tensor_parallel_size: int, pipeline_parallel_size: int, context_parallel_size: int, moe_expert_parallel_size: Optional[int], @@ -1011,7 +1013,7 @@ def _serve_llm(): model=model, tokenizer=tokenizer, custom_tokenizer=custom_tokenizer, - post_processor=post_processor, + post_processor_hook=post_processor_hook, backend=backend, max_beam_width=max_beam_width, max_batch_size=max_batch_size, diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py index ba02fbd78db6..953c4362189c 100644 --- a/tensorrt_llm/executor/postprocessor_hook.py +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -3,7 +3,7 @@ """User-pluggable post-processing hook for ``trtllm-serve`` (TRTLLM-12622). A user supplies a picklable, importable callable class via the -``--post_processor`` import path. One instance is built per owner (the ``LLM`` +``--post_processor_hook`` import path. One instance is built per owner (the ``LLM`` for the in-proxy detok path, and each post-processing worker process when enabled) and invoked once per output, per streaming chunk (plus a final call), *after* detokenization and *before* the per-endpoint response formatter. The diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index dfaf1d1bb7bf..aac9d8d7cf99 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -317,7 +317,7 @@ def __init__(self, # Build this LLM's post-processing hook for the in-proxy detok path (each # postproc worker builds its own). 
Resolving here fails fast on a bad # import path at startup rather than per-request. - _post_processor_path = getattr(self.args, "post_processor", None) + _post_processor_path = getattr(self.args, "post_processor_hook", None) self._post_processor_hook = ( load_post_processor_hook(_post_processor_path) if _post_processor_path else None) @@ -1699,7 +1699,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, - post_processor_hook=self.args.post_processor, + post_processor_hook=self.args.post_processor_hook, ), is_llm_executor=True) @@ -1848,7 +1848,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, - post_processor_hook=self.args.post_processor, + post_processor_hook=self.args.post_processor_hook, ), is_llm_executor=True, hf_model_dir=self._hf_model_dir, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 4242110cddbb..31c101ef96a4 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3394,12 +3394,12 @@ class BaseLlmArgs(StrictBaseModel): "The tokenizer class must implement 'from_pretrained(path, **kwargs)' and the TokenizerBase interface.", status="prototype") - post_processor: Optional[str] = Field( + post_processor_hook: Optional[str] = Field( default=None, description= "Python import path of a user post-processing hook applied after " "detokenization and before the per-endpoint response formatter (e.g. " - "'my_pkg.guardrail.MyPostProcessor'). The class must be importable and " + "'my_pkg.guardrail.MyPostProcessorHook'). The class must be importable and " "picklable, take no constructor arguments, and be callable as " "'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). 
" "It runs once per output, per streaming chunk, and may rewrite, " @@ -3764,9 +3764,9 @@ def validate_and_init_tokenizer(self): # and needs detokenized text to inspect; without a tokenizer it could # never run, so reject the combination rather than silently disabling # the guardrail (mirrors the harmony fail-fast in OpenAIServer). - if self.post_processor is not None: + if self.post_processor_hook is not None: raise ValueError( - "post_processor is not supported together with " + "post_processor_hook is not supported together with " "skip_tokenizer_init: the post-processing hook operates on " "detokenized text, which is unavailable when the tokenizer " "is skipped.") diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 4326bb76ab2c..94de0a424f01 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -403,9 +403,9 @@ def _init_llm(self, chat_template: Optional[str] = None): else: self.use_harmony = (type(self.model_config).model_type == "gpt_oss") - self._ensure_post_processor_supported( + self._ensure_post_processor_hook_supported( self.use_harmony, - getattr(self.generator.args, "post_processor", None)) + getattr(self.generator.args, "post_processor_hook", None)) self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: @@ -452,16 +452,16 @@ def _init_llm(self, chat_template: Optional[str] = None): self.perf_metrics_lock = asyncio.Lock() @staticmethod - def _ensure_post_processor_supported(use_harmony: bool, - post_processor: Optional[str]) -> None: - """Reject ``--post_processor`` combined with a harmony/gpt-oss model. + def _ensure_post_processor_hook_supported( + use_harmony: bool, post_processor_hook: Optional[str]) -> None: + """Reject ``--post_processor_hook`` combined with a harmony/gpt-oss model. The harmony output path is rebuilt from raw token ids, not detokenized text, so the text-based hook cannot act there. 
""" - if use_harmony and post_processor: + if use_harmony and post_processor_hook: raise ValueError( - "--post_processor is not supported with harmony/gpt-oss models " + "--post_processor_hook is not supported with harmony/gpt-oss models " "in this version: the harmony output path is reconstructed from " "raw token ids and would bypass the text-based hook. Disable the " "hook or set DISABLE_HARMONY_ADAPTER=1 if the harmony path is " diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 6756161cb1d5..072affe92c19 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -64,7 +64,7 @@ methods: annotation: Optional[str] default: null status: prototype - post_processor: + post_processor_hook: annotation: Optional[str] default: null status: prototype diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py index 528c3514f716..581453466aa5 100644 --- a/tests/unittest/executor/test_postprocessor_hook.py +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -303,18 +303,18 @@ def test_apply_method_is_noop_when_instance_has_no_hook(): assert out.text == "hello world" -def test_harmony_model_rejects_post_processor(): - """A harmony/gpt-oss model + post_processor must fail fast. +def test_harmony_model_rejects_post_processor_hook(): + """A harmony/gpt-oss model + post_processor_hook must fail fast. The harmony output path is rebuilt from raw token ids and would bypass the text-based hook, so the server refuses the combination at startup. 
""" from tensorrt_llm.serve.openai_server import OpenAIServer - guard = OpenAIServer._ensure_post_processor_supported + guard = OpenAIServer._ensure_post_processor_hook_supported with pytest.raises(ValueError, match="not supported with harmony"): - guard(use_harmony=True, post_processor="my_pkg.guardrail.Hook") + guard(use_harmony=True, post_processor_hook="my_pkg.guardrail.Hook") # Every other combination is allowed. - guard(use_harmony=False, post_processor="my_pkg.guardrail.Hook") - guard(use_harmony=True, post_processor=None) - guard(use_harmony=False, post_processor=None) + guard(use_harmony=False, post_processor_hook="my_pkg.guardrail.Hook") + guard(use_harmony=True, post_processor_hook=None) + guard(use_harmony=False, post_processor_hook=None) diff --git a/tests/unittest/llmapi/apps/_postproc_hook_samples.py b/tests/unittest/llmapi/apps/_postproc_hook_samples.py index 7ae93eb7be45..53e02fce1d60 100644 --- a/tests/unittest/llmapi/apps/_postproc_hook_samples.py +++ b/tests/unittest/llmapi/apps/_postproc_hook_samples.py @@ -5,7 +5,7 @@ These are deliberately stateless and deterministic so the test can assert the client-visible effect regardless of the (non-deterministic) model output. Each class is a top-level, no-arg-constructible, importable callable so it can be -supplied to ``trtllm-serve --post_processor`` and reconstructed by reference in +supplied to ``trtllm-serve --post_processor_hook`` and reconstructed by reference in the post-processing worker process. """ diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py index bb2571d75ab7..e4cd279727a4 100644 --- a/tests/unittest/llmapi/apps/_test_openai_post_processor.py +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 """End-to-end tests for the trtllm-serve post-processing hook. 
-Launches a real ``trtllm-serve`` with ``--post_processor`` pointing at one of +Launches a real ``trtllm-serve`` with ``--post_processor_hook`` pointing at one of the sample hooks in ``_postproc_hook_samples`` and asserts the client-visible effect (rewrite / suppress / terminate) across the chat and completions endpoints, streaming and non-streaming, with the postproc worker pool both @@ -59,7 +59,7 @@ def server(model_name: str, num_postprocess_workers: int, hook: str): "0.2", "--num_postprocess_workers", f"{num_postprocess_workers}", - "--post_processor", + "--post_processor_hook", _HOOKS[hook], ] # Make the sample-hook module importable by the server (and its postproc diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 38dc8a8cef50..ba66a033b83e 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -226,8 +226,8 @@ def test_decoding_type_eagle3_errors_on_tensorrt_backend(): TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg) -def test_post_processor_rejected_with_skip_tokenizer_init(): - """post_processor + skip_tokenizer_init must fail fast. +def test_post_processor_hook_rejected_with_skip_tokenizer_init(): + """post_processor_hook + skip_tokenizer_init must fail fast. The hook is a text-based guardrail; pairing it with skip_tokenizer_init (no detokenized text) must be rejected rather than silently disabling it. @@ -235,7 +235,7 @@ def test_post_processor_rejected_with_skip_tokenizer_init(): with pytest.raises(ValidationError, match="skip_tokenizer_init"): TorchLlmArgs(model="/tmp/dummy_model", skip_tokenizer_init=True, - post_processor="my_pkg.guardrail.Hook") + post_processor_hook="my_pkg.guardrail.Hook") # skip_tokenizer_init alone (no hook) is still fine. TorchLlmArgs(model="/tmp/dummy_model", skip_tokenizer_init=True)