-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathpostprocessor_hook.py
More file actions
233 lines (195 loc) · 9.3 KB
/
Copy pathpostprocessor_hook.py
File metadata and controls
233 lines (195 loc) · 9.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""User-pluggable post-processing hook for ``trtllm-serve`` (TRTLLM-12622).
A user supplies a picklable, importable callable class via the
``--post_processor_hook`` import path. One instance is built per owner (the ``LLM``
for the in-proxy detok path, and each post-processing worker process when
enabled) and invoked once per output, per streaming chunk (plus a final call),
*after* detokenization and *before* the per-endpoint response formatter. The
hook owns its per-request state, keyed by ``chunk.request_id``.
Stdlib-only so it can be loaded in the post-processing worker process.
"""
import dataclasses
import enum
import importlib
import logging
from typing import List, Optional, Protocol, runtime_checkable
# Public API of this module: the verdict/chunk types, the verdict helpers, and
# the two entry points (load a hook from an import path, apply it to a result).
__all__ = [
    "PostProcessorHookAction",
    "PostProcessorHookChunk",
    "PostProcessorHookVerdict",
    "PostProcessorHook",
    "emit",
    "suppress",
    "terminate",
    "apply_post_processor_hook",
    "load_post_processor_hook",
]
# Module-level logger; apply_post_processor_hook uses it to record hook
# failures and failed aborts after a terminate verdict.
logger = logging.getLogger(__name__)
def load_post_processor_hook(import_path: str) -> "PostProcessorHook":
    """Build a post-processor hook instance from a dotted import path.

    Mirrors ``tensorrt_llm.tokenizer.load_custom_tokenizer``: resolve
    ``module.path.ClassName``, import the module, instantiate it with no
    arguments. Only ``import_path`` crosses a process boundary (never the
    instance), so the class must be importable and picklable.

    Args:
        import_path: Dotted path ``module.path.ClassName`` naming a callable
            class whose constructor takes no arguments.

    Returns:
        The newly constructed hook instance.

    Raises:
        ValueError: If the path cannot be resolved, imported, or instantiated,
            or if the constructed object is not callable. Any underlying error,
            including one raised by the user module at import time or by the
            class constructor, is chained as ``__cause__``.
    """
    try:
        # ValueError here covers a path with no "." (tuple unpacking fails).
        module_path, class_name = import_path.rsplit(".", 1)
        module = importlib.import_module(module_path)
        hook_class = getattr(module, class_name)
        hook = hook_class()
    except Exception as e:
        # User code runs in this block (module import side effects and the
        # class __init__), so any exception type can surface. Normalize all of
        # them to the documented ValueError; the original stays as __cause__.
        raise ValueError(
            f"Failed to load post-processor hook '{import_path}': {e}. "
            "Expected format: 'module.path.ClassName' resolving to a "
            "no-arg-constructible callable class."
        ) from e
    if not callable(hook):
        raise ValueError(
            f"Failed to load post-processor hook '{import_path}': resolved "
            f"object is not callable (got {type(hook).__name__}); expected an "
            "instance implementing __call__(chunk)."
        )
    return hook
@dataclasses.dataclass
class PostProcessorHookChunk:
    """The payload handed to the post-processing hook for one output chunk.

    ``apply_post_processor_hook`` builds a fresh instance for every output on
    every call. The fields are copies (strings and a newly built list), so
    mutating the chunk does not change what is served; only the verdict the
    hook returns is applied.

    Attributes:
        request_id: Stable identifier for the request; the same value is passed
            for every chunk of a given response, so the hook can key its own
            per-request state on it.
        output_index: Index of the output/beam within the request.
        text_diff: Newly detokenized text produced by this chunk (streaming).
            For non-streaming requests this equals ``text``.
        text: Full accumulated detokenized text so far for this output.
        token_ids_diff: Newly generated token ids for this chunk.
        is_final: True on the terminating call for this output. It is
            request-level: every output of one call shares the same value, so
            for n>1 / beam search it only becomes True once the whole request
            is done.
        aborted: True if the request has been marked aborted in this process
            (e.g. a prior ``terminate`` verdict, or an abort observed by the
            detok process). Output-side observation only; do not rely on it to
            detect every upstream client cancellation.
        streaming: True for streaming requests.
    """

    # Field order is the positional constructor order; keep it stable.
    request_id: int
    output_index: int
    text_diff: str
    text: str
    token_ids_diff: List[int]
    is_final: bool
    aborted: bool
    streaming: bool
class PostProcessorHookAction(str, enum.Enum):
    """The kind of decision a hook returns for one chunk.

    The ``str`` mixin makes each member a string equal to its value (e.g.
    ``PostProcessorHookAction.EMIT == "emit"``), and value lookup such as
    ``PostProcessorHookAction("emit")`` maps a plain string to its member.
    """

    # Replace this chunk's new text with the verdict's text (rewrite, redact,
    # or pass through); the already-emitted prefix is kept.
    EMIT = "emit"
    # Drop this chunk's new text and withhold its token ids / logprobs.
    SUPPRESS = "suppress"
    # End the request: finish_reason "stop", stop_reason set from the verdict's
    # reason, token ids / logprobs withheld, and the request aborted.
    TERMINATE = "terminate"
@dataclasses.dataclass
class PostProcessorHookVerdict:
    """What a hook decided to do with a single chunk.

    Build instances through :func:`emit`, :func:`suppress`, or
    :func:`terminate` instead of calling this constructor yourself.

    Attributes:
        action: The decision; normalized to a :class:`PostProcessorHookAction`
            member when the verdict is constructed.
        text: Text placed after the already-emitted prefix for ``EMIT`` (and
            ``TERMINATE``) verdicts.
        reason: Reported as the output's stop_reason for ``TERMINATE``.
    """

    action: PostProcessorHookAction
    text: str = ""
    reason: Optional[str] = None

    def __post_init__(self):
        # Map plain strings such as "emit" onto enum members. An unknown value
        # raises ValueError right here, so no unrecognized action can survive
        # construction.
        if not isinstance(self.action, PostProcessorHookAction):
            self.action = PostProcessorHookAction(self.action)
def emit(text: str) -> PostProcessorHookVerdict:
    """Serve ``text`` in place of this chunk's new text.

    Pass a modified string to rewrite or redact the chunk, or the chunk's own
    ``text_diff`` to let it through unchanged.
    """
    verdict = PostProcessorHookVerdict(PostProcessorHookAction.EMIT, text)
    return verdict
def suppress() -> PostProcessorHookVerdict:
    """Hold back this chunk completely; the client sees nothing new from it."""
    return PostProcessorHookVerdict(
        action=PostProcessorHookAction.SUPPRESS,
        text="",
        reason=None,
    )
def terminate(reason: str) -> PostProcessorHookVerdict:
    """End the stream for this request, reporting ``reason`` as stop_reason."""
    action = PostProcessorHookAction.TERMINATE
    return PostProcessorHookVerdict(action, "", reason)
@runtime_checkable
class PostProcessorHook(Protocol):
    """The interface a user post-processor implements.

    The instance is built once per owner (its ``__init__`` is the one-time
    setup) and called once per output, per chunk. It owns any per-request state
    and is responsible for releasing it on ``chunk.is_final``.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    PostProcessorHook)`` works, but it only checks that a ``__call__``
    attribute exists, not its signature or return type.

    NOTE(review): a ``terminate`` verdict marks the request done while the
    terminating chunk itself carried ``is_final=False``; confirm the hook still
    gets a final ``is_final`` call afterwards, otherwise hooks should also
    release per-request state when they return ``terminate``.
    """

    # Return emit(text), suppress(), or terminate(reason) for ``chunk``.
    def __call__(self, chunk: PostProcessorHookChunk) -> PostProcessorHookVerdict: ...
def _withhold_token_channel(output, streaming: bool) -> None:
"""Withhold the raw token-id / logprob channels alongside the blanked text.
Otherwise a suppressed/terminated output leaks via them (``token_ids`` on
``/v1/completions`` with ``detokenize=False``, ``logprobs`` on both
endpoints). Streaming emits per-chunk diffs, so advancing the diff watermark
empties this chunk; non-streaming emits the full lists, so truncate them back
to the already-emitted prefix, mirroring how the text is blanked.
"""
if streaming:
output._last_token_ids_len = len(output.token_ids)
if getattr(output, "logprobs", None) is not None:
output._last_logprobs_len = len(output.logprobs)
else:
output.token_ids = output.token_ids[: output._last_token_ids_len]
if getattr(output, "logprobs", None) is not None:
output.logprobs = output.logprobs[: output._last_logprobs_len]
def _validate_hook_verdict(verdict) -> None:
    """Raise TypeError if a hook's return value cannot be applied as a verdict."""
    if not isinstance(verdict, PostProcessorHookVerdict):
        # Most commonly ``None`` from a hook that forgot to ``return``.
        raise TypeError(
            "Post-processor hook must return a PostProcessorHookVerdict "
            "(use emit(), suppress(), or terminate()); got "
            f"{type(verdict).__name__}."
        )
    # EMIT / TERMINATE concatenate ``text`` onto the emitted prefix, so it must
    # be a str; SUPPRESS ignores it.
    if verdict.action is not PostProcessorHookAction.SUPPRESS and not isinstance(
        verdict.text, str
    ):
        raise TypeError(
            f"Post-processor hook returned a {verdict.action!r} verdict with "
            f"non-str text ({type(verdict.text).__name__})."
        )


def apply_post_processor_hook(hook: PostProcessorHook, result, streaming: bool) -> None:
    """Run ``hook`` over ``result.outputs`` in place at the detok chokepoint.

    Applies each verdict by rewriting the chunk's text diff on the output
    (preserving the already-emitted prefix), suppressing it, or terminating the
    stream via the existing abort machinery.

    A hook exception fails the request closed (re-raised), never serving the
    un-vetted chunk: the in-proxy path surfaces it to the serving handler, the
    worker path converts it to an ``ErrorResponse``. Both keep the server and
    other requests alive (mirrors Triton's per-request fail-closed model).
    A hook that breaks the return contract (anything other than a
    :class:`PostProcessorHookVerdict`, or a verdict whose ``text`` is not a
    string) is handled the same way: logged with the request id and re-raised
    before the output is touched.

    Args:
        hook: The user hook, called once per entry in ``result.outputs``.
        result: The request result being post-processed. This function reads
            ``id``, ``outputs``, ``_done`` and ``_aborted``, may set ``_done`` /
            ``_aborted``, and calls ``abort()`` if present. Each output must
            expose ``index``, ``text``, ``text_diff``, ``token_ids_diff`` and
            ``_last_text_len`` (plus the token/logprob fields used when
            withholding).
        streaming: Whether the request streams; forwarded to the hook and used
            to choose how token ids / logprobs are withheld.

    Raises:
        Exception: Whatever the hook raised, re-raised unchanged.
        TypeError: If the hook returned an unusable verdict.
        ValueError: If a verdict carries an action not handled here.
    """
    # ``is_final`` is request-level (``result._done``): for n>1 / beam it fires
    # once when the whole request finishes. emit/suppress act per output, but a
    # terminate cancels the whole request (all outputs) because the engine
    # request is the unit of cancellation. Hooks needing per-sequence state
    # should key on (request_id, output_index).
    is_final = result._done
    for output in result.outputs:
        chunk = PostProcessorHookChunk(
            request_id=result.id,
            output_index=output.index,
            text_diff=output.text_diff,
            text=output.text,
            token_ids_diff=list(output.token_ids_diff),
            is_final=is_final,
            aborted=result._aborted,
            streaming=streaming,
        )
        try:
            verdict = hook(chunk)
            # Validate inside the try so a contract violation is logged and
            # fails closed exactly like an exception raised by the hook.
            _validate_hook_verdict(verdict)
        except Exception:
            logger.exception(
                "Post-processor hook failed for request %s; failing the request closed.",
                result.id,
            )
            raise
        prefix = output.text[: output._last_text_len]
        if verdict.action is PostProcessorHookAction.EMIT:
            output.text = prefix + verdict.text
        elif verdict.action is PostProcessorHookAction.SUPPRESS:
            output.text = prefix
            _withhold_token_channel(output, streaming)
        elif verdict.action is PostProcessorHookAction.TERMINATE:
            output.text = prefix + verdict.text
            _withhold_token_channel(output, streaming)
            output.finish_reason = "stop"
            output.stop_reason = verdict.reason
            result._aborted = True
            result._done = True
            # Cancel the engine request to stop wasted generation (on the worker
            # path the proxy does the actual cancel via should_abort). getattr
            # guard is defensive; real results always define abort().
            abort = getattr(result, "abort", None)
            if callable(abort):
                try:
                    abort()
                except Exception:
                    logger.exception(
                        "Failed to abort request %s after terminate verdict.", result.id
                    )
        else:
            # Only reachable if a verdict's ``action`` was reassigned after
            # construction, or a new enum member is added without a branch here.
            raise ValueError(f"Unhandled post-processor action: {verdict.action!r}")