diff --git a/docs/source/features/post-processor-hook.md b/docs/source/features/post-processor-hook.md new file mode 100644 index 000000000000..3046dc4c8045 --- /dev/null +++ b/docs/source/features/post-processor-hook.md @@ -0,0 +1,201 @@ +(post-processor-hook)= + +# Post-Processing Hook + +`trtllm-serve` supports a user-supplied **post-processing hook**: a native, per-request seam that runs +on each generated output *after* detokenization and *before* the per-endpoint response formatter. It +lets a deployment rewrite, redact, suppress, or terminate model output — including stateful logic that +spans the chunks of a streamed response — without modifying TensorRT LLM source. + +The hook is a plain Python callable class supplied by import path, mirroring `--custom_tokenizer`. It +is owned by the `LLM` instance (and built once in each post-processing worker process when those are +enabled) and invoked once per output, per streaming chunk (plus a final call), so it can hold its own +per-request state. Independent `LLM` instances in one process each own a separate hook instance. + +```{note} +This feature is a prototype and its interface may change in a future release. +``` + +For the interface definitions referenced below, see +[`tensorrt_llm/executor/postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py). + +## Enabling the hook + +Pass the dotted import path of your hook class to `--post_processor_hook`: + +```bash +trtllm-serve <model> --post_processor_hook my_pkg.guardrail.MyPostProcessorHook +``` + +Equivalently, set it in a YAML config passed via `--extra_llm_api_options`: + +```yaml +post_processor_hook: my_pkg.guardrail.MyPostProcessorHook +``` + +The class must be: + +- **Importable** — installed (`pip install`) or on `PYTHONPATH` when the server (and its + post-processing worker processes, if any) start. 
+- **Picklable and no-argument-constructible** — the hook is reconstructed by reference inside each + process; `__init__` takes no required arguments and is the one-time setup point. + +The hook works the same with or without the out-of-process post-processing worker pool +(`--num_postprocess_workers`). + +## The hook interface + +A hook implements a single method, `__call__(self, chunk) -> verdict`: + +```python +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcessorHookChunk, + PostProcessorHookVerdict, + emit, + suppress, + terminate, +) + + +class MyPostProcessorHook: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return emit(chunk.text_diff) # pass through unchanged +``` + +### `PostProcessorHookChunk` + +The payload handed to the hook for one output chunk — `request_id`, `output_index`, `text_diff`, +`text`, `token_ids_diff`, `is_final`, `aborted`, and `streaming`. See the `PostProcessorHookChunk` +dataclass in [`postprocessor_hook.py`](../../../tensorrt_llm/executor/postprocessor_hook.py) +for the authoritative field-by-field descriptions. + +### Verdicts + +Return one of the following helpers from `__call__`: + +| Helper | Effect | +|--------|--------| +| `emit(text)` | Emit `text` for this chunk. Use it to pass output through unchanged (`emit(chunk.text_diff)`) or to rewrite/redact it. Affects the **text** channel only — it does not synthesize matching `token_ids` (see *Text vs. token ids* below). | +| `suppress()` | Withhold this chunk entirely across **all** client-visible channels — text, `token_ids`, and `logprobs` (so `detokenize=false` token output is withheld too). | +| `terminate(reason)` | Stop the stream for this request, withholding the terminating chunk on all channels. `reason` is surfaced as the response `stop_reason`, and the engine request is cancelled. 
| + +Verdicts are **per chunk**: `suppress()` withholds the current chunk, and `terminate()` stops generation while keeping the chunks already emitted before it. This is consistent across streaming and non-streaming — a non-streaming response contains exactly the content the hook emitted before the first `suppress`/`terminate`, i.e. the same content a streaming client would have received. A hook that must withhold the *entire* output (all-or-nothing) should `suppress()` from the first chunk (it sees every chunk) rather than emitting and then terminating. + +## Usage examples + +### Rewrite output + +A stateless hook that upper-cases every chunk: + +```python +from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, emit + + +class UpperCaseHook: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return emit(chunk.text_diff.upper()) +``` + +### Stateful guardrail that terminates on a banned phrase + +This hook accumulates text per request (keyed by `request_id`), stops the stream as soon as a banned +phrase appears, and releases its state when the request finishes: + +```python +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcessorHookChunk, PostProcessorHookVerdict, emit, terminate, +) + + +class BannedPhraseGuard: + _BANNED = ("forbidden phrase",) + + def __init__(self): + # Per-request accumulators owned entirely by the hook. 
+ self._buffers: dict[int, str] = {} + + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + buffer = self._buffers.get(chunk.request_id, "") + chunk.text_diff.lower() + self._buffers[chunk.request_id] = buffer + + if any(phrase in buffer for phrase in self._BANNED): + self._buffers.pop(chunk.request_id, None) + return terminate("banned_phrase") + + if chunk.is_final: + self._buffers.pop(chunk.request_id, None) + return emit(chunk.text_diff) +``` + +### Suppress output + +A hook that withholds all client-visible output (text, `token_ids`, and `logprobs`): + +```python +from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookChunk, PostProcessorHookVerdict, suppress + + +class SuppressHook: + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return suppress() +``` + +## Per-request state + +The hook instance is owned by the `LLM` (built once in each post-processing worker process when the +pool is enabled) and shared across all requests it handles, so any per-request state must be keyed by +`chunk.request_id` and released when `chunk.is_final` is seen (or after a `terminate`). State is not +shared across processes or across separate `LLM` instances; when the post-processing worker pool is +enabled, all chunks of a single request are still routed to the same worker, so per-request state +remains consistent for that request. + +Engine-level batching is transparent to the hook: even when requests are batched together in the +engine, the hook is invoked **once per request** (per output, per chunk) — there is no batched-call +form — so keying state on `chunk.request_id` is sufficient to keep concurrent requests isolated. + +## Supported endpoints and limitations + +- **Endpoints**: `chat/completions` and `completions`, both streaming and non-streaming. 
The hook is + expected to apply to the `responses` endpoint for non-harmony models as well (shared detokenization + path), though that endpoint is not covered by the current end-to-end tests. +- **Not client-bypassable**: the hook is a server-side guardrail, so it runs on every response even + when a `completions` request sets `detokenize=false`. The server detokenizes for the hook regardless; + the `detokenize` flag still controls only the returned channel (text vs. token ids). A + `suppress`/`terminate` verdict withholds **all** client-visible channels — text, `token_ids`, and + `logprobs` — on both the streaming and non-streaming paths, so a client cannot recover withheld + content through any channel. +- **Requires a tokenizer**: the hook needs detokenized text to inspect, so `--post_processor_hook` combined + with `skip_tokenizer_init` is rejected at startup rather than silently disabled. +- **harmony / gpt-oss models**: not supported. Because the harmony output path is reconstructed from + raw token ids, it would bypass the text-based hook, so the server fails fast at startup when + `--post_processor_hook` is combined with a harmony model. +- **Disaggregated serving**: the context and generation servers are separate processes, each running + the hook on its own phase under a different `request_id`; per-request state cannot be correlated + across the two. A `terminate` on one phase does not propagate to the other. +- **Text vs. token ids**: `emit` rewrites the **text** channel only — it does not rewrite the underlying + `token_ids`/`logprobs`, so a client reading both should expect them to diverge. (`suppress`/`terminate` + withhold all channels, so they do not diverge.) +- **Reasoning / tool parsing**: the hook runs before the reasoning and tool-call parsers. A hook that + rewrites or suppresses text may desync those parsers; prefer `terminate`, or apply such hooks to + plain-text requests. 
+- **Hook errors fail closed**: if the hook raises, the request fails with an error rather than serving + the un-vetted chunk (the server and other requests stay alive). A returned verdict with an unknown + action is rejected the same way. +- **`n` > 1 / beam search**: `emit` and `suppress` act per output sequence, but `terminate` cancels the + **whole** request (all sequences), because the engine request is the unit of cancellation — a + `terminate` on one candidate ends the others too. Hooks needing per-sequence state should key on + `(request_id, output_index)`. + +## Tests + +The hook's unit and end-to-end tests double as runnable usage examples: + +```bash +# Unit tests for the hook contract (rewrite / suppress / terminate, per-request state, loader) +pytest tests/unittest/executor/test_postprocessor_hook.py -v +``` + +- End-to-end serving tests across endpoints, streaming modes, and worker-pool settings: + `tests/unittest/llmapi/apps/_test_openai_post_processor.py`. +- The deterministic sample hooks used by those tests: + `tests/unittest/llmapi/apps/_postproc_hook_samples.py`. diff --git a/docs/source/index.rst b/docs/source/index.rst index 9d92a8770148..8fc90f39cf1b 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -75,6 +75,7 @@ Welcome to TensorRT LLM's Documentation! 
features/quantization.md features/sampling.md features/additional-outputs.md + features/post-processor-hook.md features/guided-decoding.md features/speculative-decoding.md features/checkpoint-loading.md diff --git a/tensorrt_llm/commands/serve.py b/tensorrt_llm/commands/serve.py index aee6c37781d1..c4686aa6477c 100644 --- a/tensorrt_llm/commands/serve.py +++ b/tensorrt_llm/commands/serve.py @@ -132,7 +132,7 @@ def is_non_default_or_required(param_name, value, backend, explicit_cli_keys): """ always_include = { "model", "backend", "tokenizer", "custom_tokenizer", - "postprocess_tokenizer_dir" + "post_processor_hook", "postprocess_tokenizer_dir" } if param_name in always_include: @@ -181,6 +181,7 @@ def get_llm_args( model: str, tokenizer: Optional[str] = None, custom_tokenizer: Optional[str] = None, + post_processor_hook: Optional[str] = None, backend: str = "pytorch", max_beam_width: int = BuildConfig.model_fields["max_beam_width"]. default, @@ -238,6 +239,8 @@ def get_llm_args( tokenizer, "custom_tokenizer": custom_tokenizer, + "post_processor_hook": + post_processor_hook, "postprocess_tokenizer_dir": tokenizer or model, "kv_cache_config": @@ -634,6 +637,19 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "Custom tokenizer type: alias (e.g., 'deepseek_v32') or Python import path " "(e.g., 'tensorrt_llm.tokenizer.deepseek_v32.DeepseekV32Tokenizer').", "prototype")) +@click.option( + "--post_processor_hook", + type=str, + default=None, + help=help_info_with_stability_tag( + "Python import path of a user post-processing hook applied after " + "detokenization and before the per-endpoint response formatter (e.g. " + "'my_pkg.guardrail.MyPostProcessorHook'). The class must be importable " + "and picklable, take no constructor arguments, and be callable as " + "'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). 
" + "It runs once per output, per streaming chunk, and may rewrite, " + "suppress, or terminate the output; it owns its own per-request state.", + "prototype")) @click.option("--host", type=str, default="localhost", @@ -912,10 +928,11 @@ def convert(self, value: Any, param: Optional["click.Parameter"], "Types of agents to schedule. Now Only Support Open Deep Research agent.") def serve( model: str, tokenizer: Optional[str], custom_tokenizer: Optional[str], - host: str, port: int, log_level: str, backend: str, max_beam_width: int, - max_batch_size: int, max_num_tokens: int, max_seq_len: int, - tensor_parallel_size: int, pipeline_parallel_size: int, - context_parallel_size: int, moe_expert_parallel_size: Optional[int], + post_processor_hook: Optional[str], host: str, port: int, + log_level: str, backend: str, max_beam_width: int, max_batch_size: int, + max_num_tokens: int, max_seq_len: int, tensor_parallel_size: int, + pipeline_parallel_size: int, context_parallel_size: int, + moe_expert_parallel_size: Optional[int], moe_cluster_parallel_size: Optional[int], gpus_per_node: Optional[int], free_gpu_memory_fraction: float, kv_cache_dtype: str, num_postprocess_workers: int, trust_remote_code: bool, @@ -996,6 +1013,7 @@ def _serve_llm(): model=model, tokenizer=tokenizer, custom_tokenizer=custom_tokenizer, + post_processor_hook=post_processor_hook, backend=backend, max_beam_width=max_beam_width, max_batch_size=max_batch_size, diff --git a/tensorrt_llm/executor/postproc_worker.py b/tensorrt_llm/executor/postproc_worker.py index 4b7200d2238d..ece60b2374b1 100644 --- a/tensorrt_llm/executor/postproc_worker.py +++ b/tensorrt_llm/executor/postproc_worker.py @@ -14,6 +14,7 @@ from ..logger import logger from ..sampling_params import SamplingParams from .ipc import ZeroMqQueue +from .postprocessor_hook import load_post_processor_hook from .utils import ErrorResponse, is_llm_response if TYPE_CHECKING: @@ -46,6 +47,10 @@ class PostprocWorkerConfig: ''' The config for the 
postprocess worker. ''' num_postprocess_workers: int = 0 postprocess_tokenizer_dir: Optional[str] = None + # Dotted import path of the user post-processing hook, or + # None. NOTE: distinct from ``PostprocParams.post_processor``, which is the + # per-endpoint response *formatter* (a Callable), not this hook. + post_processor_hook: Optional[str] = None @property def enabled(self) -> bool: @@ -85,6 +90,7 @@ def __init__( tokenizer_dir: str, record_creator: Callable[ ["PostprocWorker.Input", TransformersTokenizer], Any], + post_processor_hook: Optional[str] = None, ): ''' Args: @@ -93,6 +99,7 @@ def __init__( tokenizer_dir (str): The directory to load tokenizer. record_creator (Callable[["ResponsePostprocessWorker.Input"], Any]): A creator for creating a record for a request. result_handler (Optional[Callable[[GenerationResultBase], Any]]): A callback handles the final result. + post_processor_hook (Optional[str]): Import path of the user post-processing hook; built once and threaded onto each record. ''' self._records: Dict[int, GenerationResult] = {} @@ -113,6 +120,12 @@ def __init__( # Load the tokenizer and share in all records self._tokenizer = load_hf_tokenizer(tokenizer_dir) + # Build the user post-processing hook once, like the + # tokenizer above; threaded onto each record in ``_handle_input``. + self._post_processor_hook = ( + load_post_processor_hook(post_processor_hook) + if post_processor_hook else None) + @staticmethod def default_record_creator( inp: "PostprocWorker.Input", tokenizer: TransformersTokenizer @@ -144,6 +157,10 @@ async def _handle_input( # TODO: support variant creation later self._records[req_id] = self._record_creator( input, self._tokenizer) + # Thread the hook onto the record here rather than + # via record_creator, so custom record_creators keep working. 
+ self._records[ + req_id]._post_processor_hook = self._post_processor_hook if input.disaggregated_params is not None: self._records[ req_id]._disaggregated_params = input.disaggregated_params @@ -202,6 +219,11 @@ async def handle_single_input(inp: PostprocWorker.Input, res, metrics, perf_metrics, disaggregated_params = await self._handle_input( inp) record = self._records.get(client_id) + # A `terminate` verdict forces the record done; + # honor it so the stream stops and the record is popped without + # waiting for the engine's own is_final. + if record is not None and record._done: + is_final = True should_abort = record._aborted if record else False finish_reason = record.outputs[0].finish_reason if ( record and record.outputs @@ -221,7 +243,7 @@ async def handle_single_input(inp: PostprocWorker.Input, num_generated_tokens=num_generated_tokens, )) if is_final: - self._records.pop(client_id) + self._records.pop(client_id, None) except Exception as e: logger.error( f"Postprocessing error for client {client_id}: {e}\n" @@ -268,9 +290,13 @@ async def main(): @print_traceback_on_error def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]], feedout_ipc_addr: tuple[str, Optional[bytes]], - tokenizer_dir: str, record_creator: Callable): + tokenizer_dir: str, + record_creator: Callable, + post_processor_hook: Optional[str] = None): + # Pass the hook import path; PostprocWorker builds it once. worker = PostprocWorker(feedin_ipc_addr, feedout_ipc_addr, tokenizer_dir=tokenizer_dir, - record_creator=record_creator) + record_creator=record_creator, + post_processor_hook=post_processor_hook) worker.start() diff --git a/tensorrt_llm/executor/postprocessor_hook.py b/tensorrt_llm/executor/postprocessor_hook.py new file mode 100644 index 000000000000..953c4362189c --- /dev/null +++ b/tensorrt_llm/executor/postprocessor_hook.py @@ -0,0 +1,233 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +"""User-pluggable post-processing hook for ``trtllm-serve`` (TRTLLM-12622). + +A user supplies a picklable, importable callable class via the +``--post_processor_hook`` import path. One instance is built per owner (the ``LLM`` +for the in-proxy detok path, and each post-processing worker process when +enabled) and invoked once per output, per streaming chunk (plus a final call), +*after* detokenization and *before* the per-endpoint response formatter. The +hook owns its per-request state, keyed by ``chunk.request_id``. + +Stdlib-only so it can be loaded in the post-processing worker process. +""" + +import dataclasses +import enum +import importlib +import logging +from typing import List, Optional, Protocol, runtime_checkable + +__all__ = [ + "PostProcessorHookAction", + "PostProcessorHookChunk", + "PostProcessorHookVerdict", + "PostProcessorHook", + "emit", + "suppress", + "terminate", + "apply_post_processor_hook", + "load_post_processor_hook", +] + +logger = logging.getLogger(__name__) + + +def load_post_processor_hook(import_path: str) -> "PostProcessorHook": + """Build a post-processor hook instance from a dotted import path. + + Mirrors ``tensorrt_llm.tokenizer.load_custom_tokenizer``: resolve + ``module.path.ClassName``, import the module, instantiate it with no + arguments. Only ``import_path`` crosses a process boundary (never the + instance), so the class must be importable and picklable. + + Raises: + ValueError: If the path cannot be resolved, imported, or instantiated. + """ + try: + module_path, class_name = import_path.rsplit(".", 1) + module = importlib.import_module(module_path) + hook_class = getattr(module, class_name) + hook = hook_class() + except (ValueError, ImportError, AttributeError, TypeError) as e: + raise ValueError( + f"Failed to load post-processor hook '{import_path}': {e}. " + "Expected format: 'module.path.ClassName' resolving to a " + "no-arg-constructible callable class." 
+ ) from e + if not callable(hook): + raise ValueError( + f"Failed to load post-processor hook '{import_path}': resolved " + f"object is not callable (got {type(hook).__name__}); expected an " + "instance implementing __call__(chunk)." + ) + return hook + + +@dataclasses.dataclass +class PostProcessorHookChunk: + """The payload handed to the post-processing hook for one output chunk. + + Attributes: + request_id: Stable identifier for the request; the same value is passed + for every chunk of a given response, so the hook can key its own + per-request state on it. + output_index: Index of the output/beam within the request. + text_diff: Newly detokenized text produced by this chunk (streaming). + For non-streaming requests this equals ``text``. + text: Full accumulated detokenized text so far for this output. + token_ids_diff: Newly generated token ids for this chunk. + is_final: True on the terminating call for this output. + aborted: True if the request has been marked aborted in this process + (e.g. a prior ``terminate`` verdict, or an abort observed by the + detok process). Output-side observation only; do not rely on it to + detect every upstream client cancellation. + streaming: True for streaming requests. + """ + + request_id: int + output_index: int + text_diff: str + text: str + token_ids_diff: List[int] + is_final: bool + aborted: bool + streaming: bool + + +class PostProcessorHookAction(str, enum.Enum): + """The kind of decision a hook returns for one chunk.""" + + EMIT = "emit" + SUPPRESS = "suppress" + TERMINATE = "terminate" + + +@dataclasses.dataclass +class PostProcessorHookVerdict: + """The hook's decision for one chunk. + + Use the :func:`emit`, :func:`suppress`, and :func:`terminate` helpers rather + than constructing this directly. + """ + + action: PostProcessorHookAction + text: str = "" + reason: Optional[str] = None + + def __post_init__(self): + # Coerce/validate so a hook can never smuggle an unknown action. 
+ self.action = PostProcessorHookAction(self.action) + + +def emit(text: str) -> PostProcessorHookVerdict: + """Emit ``text`` for this chunk (use to rewrite/redact, or pass through).""" + return PostProcessorHookVerdict(action=PostProcessorHookAction.EMIT, text=text) + + +def suppress() -> PostProcessorHookVerdict: + """Withhold this chunk entirely (no client-visible output).""" + return PostProcessorHookVerdict(action=PostProcessorHookAction.SUPPRESS) + + +def terminate(reason: str) -> PostProcessorHookVerdict: + """Stop the stream for this request. ``reason`` is surfaced as stop_reason.""" + return PostProcessorHookVerdict(action=PostProcessorHookAction.TERMINATE, reason=reason) + + +@runtime_checkable +class PostProcessorHook(Protocol): + """The interface a user post-processor implements. + + The instance is built once per owner (its ``__init__`` is the one-time + setup) and called once per output, per chunk. It owns any per-request state + and is responsible for releasing it on ``chunk.is_final``. + """ + + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: ... + + +def _withhold_token_channel(output, streaming: bool) -> None: + """Withhold the raw token-id / logprob channels alongside the blanked text. + + Otherwise a suppressed/terminated output leaks via them (``token_ids`` on + ``/v1/completions`` with ``detokenize=False``, ``logprobs`` on both + endpoints). Streaming emits per-chunk diffs, so advancing the diff watermark + empties this chunk; non-streaming emits the full lists, so truncate them back + to the already-emitted prefix, mirroring how the text is blanked. 
+ """ + if streaming: + output._last_token_ids_len = len(output.token_ids) + if getattr(output, "logprobs", None) is not None: + output._last_logprobs_len = len(output.logprobs) + else: + output.token_ids = output.token_ids[: output._last_token_ids_len] + if getattr(output, "logprobs", None) is not None: + output.logprobs = output.logprobs[: output._last_logprobs_len] + + +def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) -> None: + """Run ``hook`` over ``result.outputs`` in place at the detok chokepoint. + + Applies each verdict by rewriting the chunk's text diff on the output + (preserving the already-emitted prefix), suppressing it, or terminating the + stream via the existing abort machinery. + + A hook exception fails the request closed (re-raised), never serving the + un-vetted chunk: the in-proxy path surfaces it to the serving handler, the + worker path converts it to an ``ErrorResponse``. Both keep the server and + other requests alive (mirrors Triton's per-request fail-closed model). + """ + # ``is_final`` is request-level (``result._done``): for n>1 / beam it fires + # once when the whole request finishes. emit/suppress act per output, but a + # terminate cancels the whole request (all outputs) because the engine + # request is the unit of cancellation. Hooks needing per-sequence state + # should key on (request_id, output_index). 
+ is_final = result._done + for output in result.outputs: + chunk = PostProcessorHookChunk( + request_id=result.id, + output_index=output.index, + text_diff=output.text_diff, + text=output.text, + token_ids_diff=list(output.token_ids_diff), + is_final=is_final, + aborted=result._aborted, + streaming=streaming, + ) + try: + verdict = hook(chunk) + except Exception: + logger.exception( + "Post-processor hook failed for request %s; failing the request closed.", + result.id, + ) + raise + prefix = output.text[: output._last_text_len] + if verdict.action is PostProcessorHookAction.EMIT: + output.text = prefix + verdict.text + elif verdict.action is PostProcessorHookAction.SUPPRESS: + output.text = prefix + _withhold_token_channel(output, streaming) + elif verdict.action is PostProcessorHookAction.TERMINATE: + output.text = prefix + verdict.text + _withhold_token_channel(output, streaming) + output.finish_reason = "stop" + output.stop_reason = verdict.reason + result._aborted = True + result._done = True + # Cancel the engine request to stop wasted generation (on the worker + # path the proxy does the actual cancel via should_abort). getattr + # guard is defensive; real results always define abort(). + abort = getattr(result, "abort", None) + if callable(abort): + try: + abort() + except Exception: + logger.exception( + "Failed to abort request %s after terminate verdict.", result.id + ) + else: + # Unreachable for hook-returned verdicts (validated in + # ``__post_init__``); guards an unhandled future enum member. + raise ValueError(f"Unhandled post-processor action: {verdict.action!r}") diff --git a/tensorrt_llm/executor/proxy.py b/tensorrt_llm/executor/proxy.py index a0cb59111c9d..ffe793bc5b48 100644 --- a/tensorrt_llm/executor/proxy.py +++ b/tensorrt_llm/executor/proxy.py @@ -327,7 +327,14 @@ def process_res(res): nonlocal event_loop nonlocal async_queues - queue = self._results[client_id].queue + # The result may already be finalized and popped below — e.g. 
a + # `terminate` verdict sent a final Output while the + # engine still has in-flight responses (abort is async). Drop such + # late responses instead of crashing with a KeyError. + result = self._results.get(client_id) + if result is None: + return + queue = result.queue if isinstance(queue, _SyncQueue): queue.put_nowait(res) async_queues.append(queue) @@ -342,7 +349,7 @@ def process_res(res): res, ErrorResponse) or (isinstance(res, PostprocWorker.Output) and res.is_final): - self._results.pop(client_id) + self._results.pop(client_id, None) res = res if isinstance(res, list) else [res] diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index 97b71cb7af0b..4a16ee76587f 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -30,6 +30,7 @@ from ..metrics.perf_utils import \ process_req_perf_metrics as _process_req_perf_metrics from ..sampling_params import LogprobParams, SamplingParams +from .postprocessor_hook import PostProcessorHook, apply_post_processor_hook from .utils import ErrorResponse, has_event_loop, is_llm_response if TYPE_CHECKING: @@ -828,7 +829,8 @@ def __init__(self, tokenizer: Optional[Callable] = None, streaming: bool = False, background_error_handler: Optional[Callable] = None, - postproc_params: Optional["PostprocParams"] = None): + postproc_params: Optional["PostprocParams"] = None, + post_processor_hook: Optional[PostProcessorHook] = None): super().__init__( id, sampling_params, @@ -837,6 +839,9 @@ def __init__(self, ) self.tokenizer = tokenizer self._streaming = streaming + # User post-processing hook, threaded in alongside the + # tokenizer by this result's creator; None when unconfigured. 
+ self._post_processor_hook = post_processor_hook def _handle_response(self, response: "GenerationExecutor.Response"): GenerationResultBase._handle_response(self, response) @@ -851,7 +856,12 @@ def _handle_response(self, response: "GenerationExecutor.Response"): 'spaces_between_special_tokens': self.sampling_params.spaces_between_special_tokens } - if self.sampling_params.detokenize and self.tokenizer is not None: + # Detokenize when the client asked for text, OR whenever a post-processing + # hook is configured: the hook runs on every response regardless of the + # client's ``detokenize`` flag, else it could be bypassed with + # ``detokenize=False``. + if (self.sampling_params.detokenize or self._post_processor_hook + is not None) and self.tokenizer is not None: for beam_output in self.outputs: beam_output._last_text_len = len(beam_output.text) if hasattr( @@ -891,6 +901,21 @@ def _handle_response(self, response: "GenerationExecutor.Response"): self._done = True break + self._apply_post_processor_hook() + + def _apply_post_processor_hook(self): + """Run the user post-processing hook at the detok chokepoint. + + Runs after detok populated ``text``/``text_diff`` and before any + per-endpoint formatter reads them. The hook instance is threaded in via + ``post_processor_hook`` (``None`` when unconfigured) so independent + ``LLM`` instances stay isolated. 
+ """ + hook = self._post_processor_hook + if hook is None: + return + apply_post_processor_hook(hook, self, streaming=self._streaming) + # alias PostprocWorker = DetokenizedGenerationResultBase.PostprocWorker diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index 62960200ca26..e39533dad5ed 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -281,6 +281,7 @@ def notify_proxy_threads_to_quit(): proxy_result_queue, postproc_worker_config.postprocess_tokenizer_dir, PostprocWorker.default_record_creator, + postproc_worker_config.post_processor_hook, ) postprocess_worker_futures.append(fut) diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index 2659d46449e2..aac9d8d7cf99 100644 --- a/tensorrt_llm/llmapi/llm.py +++ b/tensorrt_llm/llmapi/llm.py @@ -32,6 +32,8 @@ GenerationResult, IterationResult, LoRARequest, PostprocWorkerConfig, PromptAdapterRequest) from ..executor.postproc_worker import PostprocParams +from ..executor.postprocessor_hook import (PostProcessorHook, + load_post_processor_hook) from ..executor.request import DEFAULT_REQUEST_PRIORITY from ..executor.utils import (RequestError, create_mpi_comm_session, get_spawn_proxy_process_env) @@ -76,15 +78,20 @@ def __init__(self) -> None: @classmethod def _from_generation_result( - cls, - generation_result: GenerationResult, - prompt: Optional[str] = None, - tokenizer: Optional[TokenizerBase] = None) -> 'RequestOutput': + cls, + generation_result: GenerationResult, + prompt: Optional[str] = None, + tokenizer: Optional[TokenizerBase] = None, + post_processor_hook: Optional[PostProcessorHook] = None + ) -> 'RequestOutput': inst = cls.__new__(cls) inst.__dict__.update(generation_result.__dict__) inst.tokenizer = tokenizer inst._streaming = generation_result._streaming inst._prompt = prompt + # User post-processing hook; threaded onto the result the + # user holds, where the in-proxy detok runs. None when unconfigured. 
+ inst._post_processor_hook = post_processor_hook return inst @property @@ -307,6 +314,14 @@ def __init__(self, "yellow") self.mpi_session = self.args.mpi_session + # Build this LLM's post-processing hook for the in-proxy detok path (each + # postproc worker builds its own). Resolving here fails fast on a bad + # import path at startup rather than per-request. + _post_processor_path = getattr(self.args, "post_processor_hook", None) + self._post_processor_hook = ( + load_post_processor_hook(_post_processor_path) + if _post_processor_path else None) + if self.args.parallel_config.is_multi_gpu: if os.getenv("RAY_LOCAL_WORLD_SIZE") is None and get_device_count( ) < self.args.parallel_config.world_size_per_node: @@ -681,8 +696,11 @@ def generate_async( result.metrics_dict.update( {MetricNames.ARRIVAL_TIMESTAMP: time.time()}) - return RequestOutput._from_generation_result(result, prompt, - self.tokenizer) + return RequestOutput._from_generation_result( + result, + prompt, + self.tokenizer, + post_processor_hook=self._post_processor_hook) def _preprocess( self, @@ -1681,6 +1699,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, + post_processor_hook=self.args.post_processor_hook, ), is_llm_executor=True) @@ -1829,6 +1848,7 @@ def _build_model(self): postproc_worker_config=PostprocWorkerConfig( num_postprocess_workers=self.args.num_postprocess_workers, postprocess_tokenizer_dir=self.args.postprocess_tokenizer_dir, + post_processor_hook=self.args.post_processor_hook, ), is_llm_executor=True, hf_model_dir=self._hf_model_dir, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 609e11c5a27a..31c101ef96a4 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -3394,6 +3394,18 @@ class BaseLlmArgs(StrictBaseModel): "The tokenizer class must implement 'from_pretrained(path, 
**kwargs)' and the TokenizerBase interface.", status="prototype") + post_processor_hook: Optional[str] = Field( + default=None, + description= + "Python import path of a user post-processing hook applied after " + "detokenization and before the per-endpoint response formatter (e.g. " + "'my_pkg.guardrail.MyPostProcessorHook'). The class must be importable and " + "picklable, take no constructor arguments, and be callable as " + "'__call__(chunk) -> verdict' (see tensorrt_llm.executor.postprocessor_hook). " + "It runs once per output, per streaming chunk, and may rewrite, " + "suppress, or terminate the output; it owns its own per-request state.", + status="prototype") + skip_tokenizer_init: bool = Field( default=False, description="Whether to skip the tokenizer initialization.") @@ -3748,6 +3760,16 @@ def validate_parallel_config(self): def validate_and_init_tokenizer(self): """Initialize tokenizer based on configuration.""" if self.skip_tokenizer_init: + # The post-processing hook is a text-based guardrail + # and needs detokenized text to inspect; without a tokenizer it could + # never run, so reject the combination rather than silently disabling + # the guardrail (mirrors the harmony fail-fast in OpenAIServer). 
+ if self.post_processor_hook is not None: + raise ValueError( + "post_processor_hook is not supported together with " + "skip_tokenizer_init: the post-processing hook operates on " + "detokenized text, which is unavailable when the tokenizer " + "is skipped.") self.tokenizer = None elif self.custom_tokenizer: # If tokenizer is already a tokenizer object, custom_tokenizer is not compatible diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index 24dae94f0123..94de0a424f01 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -403,6 +403,10 @@ def _init_llm(self, chat_template: Optional[str] = None): else: self.use_harmony = (type(self.model_config).model_type == "gpt_oss") + self._ensure_post_processor_hook_supported( + self.use_harmony, + getattr(self.generator.args, "post_processor_hook", None)) + self.tool_call_id_type = "random" # default tool call id type is random if self.model_config is not None: # NOTE: Use the instance-level ``model_type`` (JSON-derived) here, not @@ -447,6 +451,22 @@ def _init_llm(self, chat_template: Optional[str] = None): self.perf_metrics = deque(maxlen=max_perf_metrics) self.perf_metrics_lock = asyncio.Lock() + @staticmethod + def _ensure_post_processor_hook_supported( + use_harmony: bool, post_processor_hook: Optional[str]) -> None: + """Reject ``--post_processor_hook`` combined with a harmony/gpt-oss model. + + The harmony output path is rebuilt from raw token ids, not detokenized + text, so the text-based hook cannot act there. + """ + if use_harmony and post_processor_hook: + raise ValueError( + "--post_processor_hook is not supported with harmony/gpt-oss models " + "in this version: the harmony output path is reconstructed from " + "raw token ids and would bypass the text-based hook. 
Disable the " + "hook or set DISABLE_HARMONY_ADAPTER=1 if the harmony path is " + "not needed.") + def _logit_bias_vocab_size(self) -> int: for config in (self.model_config, getattr(self.model_config, "text_config", None)): diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 64a40cebc21c..3ac865413622 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -902,6 +902,13 @@ def test_openai_tool_call(llm_root, llm_venv): str(test_root / "_test_openai_tool_call.py")]) +def test_openai_post_processor(llm_root, llm_venv): + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_post_processor.py")]) + + @pytest.mark.parametrize("sampler", ["torch_sampler", "trtllm_sampler"]) def test_openai_completions_with_logit_bias(llm_root, llm_venv, sampler: str): test_root = unittest_path() / "llmapi" / "apps" diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml index 3e61e02ee1f6..884aa1e080b5 100644 --- a/tests/integration/test_lists/test-db/l0_a10.yml +++ b/tests/integration/test_lists/test-db/l0_a10.yml @@ -106,6 +106,7 @@ l0_a10: - test_e2e.py::test_openai_misc_example[pytorch] - test_e2e.py::test_openai_reasoning[pytorch] - test_e2e.py::test_openai_tool_call + - test_e2e.py::test_openai_post_processor - test_e2e.py::test_openai_responses_entrypoint - test_e2e.py::test_openai_completions_example[pytorch] - test_e2e.py::test_openai_chat_example[pytorch] TIMEOUT (90) @@ -130,6 +131,8 @@ l0_a10: - unittest/executor/test_rpc.py - unittest/executor/test_ipc.py - unittest/executor/test_fatal_error_health_check.py + - unittest/executor/test_postprocessor_hook.py + - unittest/executor/test_proxy_postproc_terminate.py # trtllm-serve CPU-only - unittest/llmapi/apps/test_chat_utils.py - unittest/llmapi/apps/test_tool_parsers.py diff --git a/tests/unittest/api_stability/references/llm.yaml 
b/tests/unittest/api_stability/references/llm.yaml index a415fdebcb98..072affe92c19 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -64,6 +64,10 @@ methods: annotation: Optional[str] default: null status: prototype + post_processor_hook: + annotation: Optional[str] + default: null + status: prototype # reasoning reasoning_parser: annotation: Optional[str] diff --git a/tests/unittest/executor/test_postprocessor_hook.py b/tests/unittest/executor/test_postprocessor_hook.py new file mode 100644 index 000000000000..581453466aa5 --- /dev/null +++ b/tests/unittest/executor/test_postprocessor_hook.py @@ -0,0 +1,320 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Unit tests for the trtllm-serve post-processing hook.""" + +import pytest + +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcessorHookChunk, + apply_post_processor_hook, + emit, + load_post_processor_hook, + suppress, + terminate, +) +from tensorrt_llm.executor.result import CompletionOutput + + +class _FakeResult: + """Minimal stand-in for a GenerationResult at the detok chokepoint.""" + + def __init__( + self, + outputs, + req_id=1, + done=False, + aborted=False, + has_abort=False, + streaming=True, + post_processor_hook=None, + ): + self.outputs = outputs + self.id = req_id + self._done = done + self._aborted = aborted + self._streaming = streaming + # Per-instance hook ownership: the detok read site reads + # this attribute rather than a process global. 
+ self._post_processor_hook = post_processor_hook + self.abort_called = 0 + if has_abort: + self.abort = self._abort + + def _abort(self): + self.abort_called += 1 + self._aborted = True + + +def _make_output(text, last_text_len=0, index=0, token_ids=None, logprobs=None): + out = CompletionOutput(index=index, text=text) + out.token_ids = token_ids if token_ids is not None else [1, 2, 3] + out.logprobs = logprobs if logprobs is not None else [-0.1, -0.2, -0.3] + out._last_text_len = last_text_len + return out + + +def test_rewrite_streaming_diff_replaces_only_the_diff(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + def hook(chunk: PostProcessorHookChunk): + assert chunk.text_diff == " world" + assert chunk.text == "hello world" + return emit(chunk.text_diff.upper()) + + apply_post_processor_hook(hook, result, streaming=True) + + # The already-emitted prefix is preserved; only the diff is rewritten. + assert out.text == "hello WORLD" + assert out.text_diff == " WORLD" + + +def test_suppress_withholds_the_diff_keeping_prefix(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=True) + + assert out.text == "hello" + assert out.text_diff == "" + + +def test_terminate_marks_aborted_and_done_with_reason(): + out = _make_output("safe bad", last_text_len=len("safe ")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: terminate("policy_violation"), result, streaming=True) + + assert result._aborted is True + assert result._done is True + assert out.finish_reason == "stop" + assert out.stop_reason == "policy_violation" + # The violating diff is withheld. + assert out.text == "safe " + + +def test_suppress_blanks_token_and_logprob_diffs(): + # mid-stream chunk with one new token/logprob beyond what was emitted. 
+ out = _make_output("hello world", last_text_len=len("hello")) + out._last_token_ids_len = 2 + out._last_logprobs_len = 2 + assert out.token_ids_diff == [3] + assert out.logprobs_diff == [-0.3] + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=True) + + assert out.text == "hello" + # The raw token/logprob channel is withheld too, not just the text. + assert out.token_ids_diff == [] + assert out.logprobs_diff == [] + + +def test_suppress_blanks_full_token_and_logprob_channels_non_streaming(): + """Non-streaming suppress must blank the full token/logprob channels. + + Non-streaming emits the full token_ids/logprobs (not the diff), so suppress + must truncate the full channels — otherwise a detokenize=False completion + (token_ids) or any logprobs response leaks the withheld output. + """ + out = _make_output("the full answer", last_text_len=0) + # Non-streaming single response: the diff watermark stays at 0. + assert out.token_ids == [1, 2, 3] + assert out.logprobs == [-0.1, -0.2, -0.3] + result = _FakeResult([out], done=True) + + apply_post_processor_hook(lambda c: suppress(), result, streaming=False) + + assert out.text == "" + # The full channels the non-streaming formatter reads are emptied, not just + # the diff view. 
+ assert out.token_ids == [] + assert out.logprobs == [] + + +def test_terminate_calls_abort_when_available_and_blanks_token_channel(): + out = _make_output("safe bad", last_text_len=len("safe ")) + out._last_token_ids_len = 1 + result = _FakeResult([out], has_abort=True) + + apply_post_processor_hook(lambda c: terminate("policy"), result, streaming=True) + + assert result._aborted is True + assert result._done is True + assert result.abort_called == 1 + assert out.token_ids_diff == [] + assert out.finish_reason == "stop" + assert out.stop_reason == "policy" + + +def test_terminate_without_abort_attr_does_not_crash(): + out = _make_output("safe bad", last_text_len=len("safe ")) + result = _FakeResult([out], has_abort=False) + apply_post_processor_hook(lambda c: terminate("policy"), result, streaming=True) + assert result._aborted is True + assert result._done is True + + +def test_hook_exception_fails_closed_and_reraises(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + def boom(chunk): + raise RuntimeError("hook bug") + + # Fail-closed: the exception propagates so the request errors rather than + # serving the un-vetted chunk. No verdict is applied, so the text is unmodified. + with pytest.raises(RuntimeError, match="hook bug"): + apply_post_processor_hook(boom, result, streaming=True) + + assert out.text == "hello world" + + +def test_non_streaming_rewrites_full_text(): + # Non-stream single response: _last_text_len == 0, so diff == full text. 
+ out = _make_output("the full answer", last_text_len=0) + result = _FakeResult([out], done=True) + + def hook(chunk: PostProcessorHookChunk): + assert chunk.text_diff == chunk.text == "the full answer" + return emit("REDACTED") + + apply_post_processor_hook(hook, result, streaming=False) + + assert out.text == "REDACTED" + + +def test_passthrough_emit_is_idempotent(): + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out]) + + apply_post_processor_hook(lambda c: emit(c.text_diff), result, streaming=True) + + assert out.text == "hello world" + assert out.text_diff == " world" + + +def test_per_request_state_is_keyed_by_request_id(): + """The hook owns its per-request state; trtllm only passes request_id.""" + + class Counter: + def __init__(self): + self.state = {} + + def __call__(self, chunk: PostProcessorHookChunk): + n = self.state.get(chunk.request_id, 0) + 1 + self.state[chunk.request_id] = n + if chunk.is_final: + self.state.pop(chunk.request_id, None) + return emit(f"{chunk.text_diff}#{n}") + + hook = Counter() + r1a = _FakeResult([_make_output("a", 0)], req_id=1) + r2 = _FakeResult([_make_output("b", 0)], req_id=2) + r1b = _FakeResult([_make_output("ac", 1)], req_id=1, done=True) + + apply_post_processor_hook(hook, r1a, streaming=True) + apply_post_processor_hook(hook, r2, streaming=True) + apply_post_processor_hook(hook, r1b, streaming=True) + + # Request 1 counts 1 then 2, independently of request 2 (which counts 1). + assert r1a.outputs[0].text == "a#1" + assert r2.outputs[0].text == "b#1" + assert r1b.outputs[0].text == "ac#2" + # State released on the final chunk. 
+ assert 1 not in hook.state + + +def test_unknown_verdict_action_rejected_at_construction(): + """An unknown action cannot be smuggled: the verdict rejects it on build.""" + from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookVerdict + + with pytest.raises(ValueError): + PostProcessorHookVerdict(action="bogus") + + +def test_unknown_action_fails_closed_through_apply(): + """A hook constructing a bad verdict fails the request closed (re-raised).""" + from tensorrt_llm.executor.postprocessor_hook import PostProcessorHookVerdict + + out = _make_output("x", 0) + result = _FakeResult([out]) + with pytest.raises(ValueError): + apply_post_processor_hook( + lambda c: PostProcessorHookVerdict(action="bogus"), result, streaming=True + ) + + +def test_loader_resolves_import_path(): + # Any importable, no-arg-constructible callable instance works as a smoke + # test. MagicMock instances are callable, unlike e.g. OrderedDict. + hook = load_post_processor_hook("unittest.mock.MagicMock") + assert hook is not None + + +def test_loader_raises_on_bad_path(): + with pytest.raises(ValueError, match="Failed to load post-processor hook"): + load_post_processor_hook("no.such.module.Nope") + + +def test_loader_raises_on_non_callable(): + """A non-callable resolved instance fails at load time, not per-chunk. + + OrderedDict instances are not callable, so they exercise this path. + """ + with pytest.raises(ValueError, match="not callable"): + load_post_processor_hook("collections.OrderedDict") + + +def test_loader_builds_independent_instances(): + """Each owner builds its own instance (no shared process-global cache). + + This is the core of per-instance ownership: two owners + loading the same import path get distinct instances, so their per-request + state never collides. 
+ """ + a = load_post_processor_hook("unittest.mock.MagicMock") + b = load_post_processor_hook("unittest.mock.MagicMock") + assert a is not b + + +def test_apply_method_reads_hook_from_instance_attribute(): + """The detok read site applies the hook owned by the result instance.""" + from tensorrt_llm.executor.result import DetokenizedGenerationResultBase + + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out], post_processor_hook=lambda c: emit(c.text_diff.upper())) + + # Call the real read site against the per-instance attribute. + DetokenizedGenerationResultBase._apply_post_processor_hook(result) + + assert out.text == "hello WORLD" + + +def test_apply_method_is_noop_when_instance_has_no_hook(): + """With no hook on the instance, the chunk passes through untouched.""" + from tensorrt_llm.executor.result import DetokenizedGenerationResultBase + + out = _make_output("hello world", last_text_len=len("hello")) + result = _FakeResult([out], post_processor_hook=None) + + DetokenizedGenerationResultBase._apply_post_processor_hook(result) + + assert out.text == "hello world" + + +def test_harmony_model_rejects_post_processor_hook(): + """A harmony/gpt-oss model + post_processor_hook must fail fast. + + The harmony output path is rebuilt from raw token ids and would bypass the + text-based hook, so the server refuses the combination at startup. + """ + from tensorrt_llm.serve.openai_server import OpenAIServer + + guard = OpenAIServer._ensure_post_processor_hook_supported + with pytest.raises(ValueError, match="not supported with harmony"): + guard(use_harmony=True, post_processor_hook="my_pkg.guardrail.Hook") + # Every other combination is allowed. 
+ guard(use_harmony=False, post_processor_hook="my_pkg.guardrail.Hook") + guard(use_harmony=True, post_processor_hook=None) + guard(use_harmony=False, post_processor_hook=None) diff --git a/tests/unittest/executor/test_proxy_postproc_terminate.py b/tests/unittest/executor/test_proxy_postproc_terminate.py new file mode 100644 index 000000000000..00b6883c8734 --- /dev/null +++ b/tests/unittest/executor/test_proxy_postproc_terminate.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Regression tests for proxy late-response handling after a hook terminate. + +When a post-processor hook returns ``terminate`` the result is +marked done and popped from the +proxy's ``_results`` map, but the engine may still have in-flight responses for +the same ``client_id`` (abort is async). ``GenerationExecutorProxy.dispatch_result_task`` +must drop those late responses via ``self._results.get(client_id)`` rather than +``pop``-ing a missing key and crashing the dispatch thread with a ``KeyError``. +""" + +from types import SimpleNamespace + +from tensorrt_llm.executor.proxy import GenerationExecutorProxy + + +class _FakeResultQueue: + """Returns one queued item per ``get()`` call (mirrors the IPC queue).""" + + def __init__(self, items): + self._items = list(items) + + def get(self): + return self._items.pop(0) + + +class _RecordingQueue: + """Stand-in for a result's delivery queue (non-``_SyncQueue`` branch).""" + + def __init__(self): + self.delivered = [] + + def put(self, res): + self.delivered.append(res) + + +def _make_proxy(): + # Avoid GenerationExecutorProxy.__init__ (it spawns workers); we only + # exercise the pure dispatch logic with hand-set attributes. 
+ return object.__new__(GenerationExecutorProxy) + + +def test_late_response_after_terminate_is_dropped_without_keyerror(): + """A response for an already-popped client_id must be dropped, not crash.""" + proxy = _make_proxy() + proxy._results = {} # terminate already finalized + popped this client_id + + late = SimpleNamespace(client_id=999, has_error=False, result=SimpleNamespace(is_final=True)) + proxy.result_queue = _FakeResultQueue([late]) + + # Must not raise KeyError; the dispatch loop stays alive (returns True). + assert proxy.dispatch_result_task() is True + + +def test_final_response_is_delivered_and_popped(): + """A live, final response is delivered and removed from ``_results``. + + This establishes the exact condition under which a later duplicate becomes + the 'late response' that the drop above guards against. + """ + proxy = _make_proxy() + queue = _RecordingQueue() + proxy._results = {7: SimpleNamespace(queue=queue)} + + resp = SimpleNamespace(client_id=7, has_error=False, result=SimpleNamespace(is_final=True)) + proxy.result_queue = _FakeResultQueue([resp]) + + assert proxy.dispatch_result_task() is True + assert queue.delivered == [resp] + assert 7 not in proxy._results # popped on is_final diff --git a/tests/unittest/llmapi/apps/_postproc_hook_samples.py b/tests/unittest/llmapi/apps/_postproc_hook_samples.py new file mode 100644 index 000000000000..53e02fce1d60 --- /dev/null +++ b/tests/unittest/llmapi/apps/_postproc_hook_samples.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Sample post-processing hooks for the trtllm-serve hook integration test. + +These are deliberately stateless and deterministic so the test can assert the +client-visible effect regardless of the (non-deterministic) model output. 
Each +class is a top-level, no-arg-constructible, importable callable so it can be +supplied to ``trtllm-serve --post_processor_hook`` and reconstructed by reference in +the post-processing worker process. +""" + +from tensorrt_llm.executor.postprocessor_hook import ( + PostProcessorHookChunk, + PostProcessorHookVerdict, + emit, + suppress, + terminate, +) + + +class UppercaseHook: + """Rewrite every chunk's text to upper case.""" + + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return emit(chunk.text_diff.upper()) + + +class SuppressHook: + """Withhold all output (every chunk is suppressed).""" + + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return suppress() + + +class TerminateHook: + """Terminate the stream immediately on the first chunk seen.""" + + def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: + return terminate("test_policy") diff --git a/tests/unittest/llmapi/apps/_test_openai_post_processor.py b/tests/unittest/llmapi/apps/_test_openai_post_processor.py new file mode 100644 index 000000000000..e4cd279727a4 --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_post_processor.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""End-to-end tests for the trtllm-serve post-processing hook. + +Launches a real ``trtllm-serve`` with ``--post_processor_hook`` pointing at one of +the sample hooks in ``_postproc_hook_samples`` and asserts the client-visible +effect (rewrite / suppress / terminate) across the chat and completions +endpoints, streaming and non-streaming, with the postproc worker pool both +disabled (in-proxy detok) and enabled (worker-process detok). 
+ +The hooks are stateless and deterministic so assertions hold regardless of the +model's (non-deterministic) output: + - UppercaseHook: every chunk is upper-cased -> output == output.upper() + - SuppressHook: every chunk is withheld -> output == "" + - TerminateHook: first chunk terminates -> output == "", stops early +""" + +import os + +import openai +import pytest + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +pytestmark = pytest.mark.threadleak(enabled=False) + +# Dotted import paths into the sample-hook module shipped alongside this test. +_HOOKS = { + "uppercase": "_postproc_hook_samples.UppercaseHook", + "suppress": "_postproc_hook_samples.SuppressHook", + "terminate": "_postproc_hook_samples.TerminateHook", +} + + +@pytest.fixture(scope="module") +def model_name(): + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module", params=[0, 2], ids=["disable_processpool", "enable_processpool"]) +def num_postprocess_workers(request): + return request.param + + +@pytest.fixture(scope="module", params=list(_HOOKS), ids=list(_HOOKS)) +def hook(request): + return request.param + + +@pytest.fixture(scope="module") +def server(model_name: str, num_postprocess_workers: int, hook: str): + model_path = get_model_path(model_name) + args = [ + "--backend", + "pytorch", + # co-exist with other servers + "--kv_cache_free_gpu_memory_fraction", + "0.2", + "--num_postprocess_workers", + f"{num_postprocess_workers}", + "--post_processor_hook", + _HOOKS[hook], + ] + # Make the sample-hook module importable by the server (and its postproc + # worker) subprocesses. 
+ apps_dir = os.path.dirname(os.path.abspath(__file__)) + env = os.environ.copy() + env["PYTHONPATH"] = apps_dir + os.pathsep + env.get("PYTHONPATH", "") + with RemoteOpenAIServer(model_path, args, env=env) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer): + return server.get_client() + + +@pytest.fixture(scope="module") +def async_client(server: RemoteOpenAIServer): + return server.get_async_client() + + +def _assert_text_matches_hook(hook: str, text: str): + if hook == "uppercase": + assert text == text.upper(), f"expected upper-cased text, got {text!r}" + elif hook == "suppress": + assert text == "", f"expected suppressed (empty) text, got {text!r}" + elif hook == "terminate": + assert text == "", f"expected terminated (empty) text, got {text!r}" + else: + raise AssertionError(f"unknown hook {hook}") + + +def test_completions_non_streaming(client: openai.OpenAI, model_name: str, hook: str): + completion = client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + ) + text = completion.choices[0].text + _assert_text_matches_hook(hook, text) + if hook == "terminate": + assert completion.choices[0].finish_reason == "stop" + + +def test_completions_detokenize_false_does_not_bypass_hook( + client: openai.OpenAI, model_name: str, hook: str +): + """A server-side hook is a guardrail and must run even when the client sets + ``detokenize=false`` — that flag controls only the returned channel, not + whether the hook executes.""" + completion = client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + extra_body={"detokenize": False}, + ) + choice = completion.choices[0] + if hook == "terminate": + # The terminate hook fires on the first chunk regardless of the text + # channel, so the request stops early instead of running to max_tokens. 
+ # If detokenize=false bypassed the hook, this would be "length". + assert choice.finish_reason == "stop" + # The withheld content must not leak through the token-id channel that + # detokenize=false returns. + assert not choice.token_ids + elif hook == "suppress": + # detokenize=false returns token_ids; suppress must withhold them too. + assert not choice.token_ids + + +@pytest.mark.asyncio(loop_scope="module") +async def test_completions_streaming(async_client: openai.AsyncOpenAI, model_name: str, hook: str): + stream = await async_client.completions.create( + model=model_name, + prompt="Hello, my name is", + max_tokens=16, + temperature=0.0, + stream=True, + ) + text = "" + finish_reason = None + async for chunk in stream: + token = chunk.choices[0].text + if token: + text += token + if chunk.choices[0].finish_reason is not None: + finish_reason = chunk.choices[0].finish_reason + _assert_text_matches_hook(hook, text) + if hook == "terminate": + # The hook (not an empty generation) must have stopped the stream. 
+ assert finish_reason == "stop", ( + f"expected finish_reason 'stop' from terminate, got {finish_reason!r}" + ) + + +def test_chat_non_streaming(client: openai.OpenAI, model_name: str, hook: str): + chat = client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Hello, tell me a short story."}], + max_tokens=16, + temperature=0.0, + ) + content = chat.choices[0].message.content or "" + _assert_text_matches_hook(hook, content) + if hook == "terminate": + assert chat.choices[0].finish_reason == "stop" + + +@pytest.mark.asyncio(loop_scope="module") +async def test_chat_streaming(async_client: openai.AsyncOpenAI, model_name: str, hook: str): + stream = await async_client.chat.completions.create( + model=model_name, + messages=[{"role": "user", "content": "Hello, tell me a short story."}], + max_tokens=16, + temperature=0.0, + stream=True, + ) + content = "" + finish_reason = None + async for chunk in stream: + delta = chunk.choices[0].delta.content + if delta: + content += delta + if chunk.choices[0].finish_reason is not None: + finish_reason = chunk.choices[0].finish_reason + _assert_text_matches_hook(hook, content) + if hook == "terminate": + # The hook (not an empty generation) must have stopped the stream. + assert finish_reason == "stop", ( + f"expected finish_reason 'stop' from terminate, got {finish_reason!r}" + ) diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index 5556f3bfbf04..ba66a033b83e 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -226,6 +226,20 @@ def test_decoding_type_eagle3_errors_on_tensorrt_backend(): TrtLlmArgs(model=llama_model_path, speculative_config=spec_cfg) +def test_post_processor_hook_rejected_with_skip_tokenizer_init(): + """post_processor_hook + skip_tokenizer_init must fail fast. 
+ + The hook is a text-based guardrail; pairing it with skip_tokenizer_init (no + detokenized text) must be rejected rather than silently disabling it. + """ + with pytest.raises(ValidationError, match="skip_tokenizer_init"): + TorchLlmArgs(model="/tmp/dummy_model", + skip_tokenizer_init=True, + post_processor_hook="my_pkg.guardrail.Hook") + # skip_tokenizer_init alone (no hook) is still fine. + TorchLlmArgs(model="/tmp/dummy_model", skip_tokenizer_init=True) + + class TestModelDefaults: """Test suite for model-specific default overrides functionality."""