-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathpostproc_worker.py
More file actions
302 lines (264 loc) · 12.4 KB
/
Copy pathpostproc_worker.py
File metadata and controls
302 lines (264 loc) · 12.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
import asyncio
import traceback
from collections import deque
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, NamedTuple,
Optional, Union)
import zmq
from .._utils import nvtx_range_debug
from ..bindings import executor as tllm
from ..llmapi.tokenizer import TransformersTokenizer, load_hf_tokenizer
from ..llmapi.utils import print_traceback_on_error
from ..logger import logger
from ..sampling_params import SamplingParams
from .ipc import ZeroMqQueue
from .postprocessor_hook import load_post_processor_hook
from .utils import ErrorResponse, is_llm_response
if TYPE_CHECKING:
from ..disaggregated_params import DisaggregatedParams
from .result import (DetokenizedGenerationResultBase, GenerationResult,
GenerationResultBase, ResponseWrapper)
# Public API of this module. PostprocArgs and PostprocParams are importable
# by name but are not listed here, so ``from ... import *`` does not export
# them.
__all__ = [
    "PostprocWorker",
    "PostprocWorkerConfig",
]
@dataclass(kw_only=True)
class PostprocArgs:
    """Per-request arguments handed to a ``PostprocParams.post_processor``.

    All fields are keyword-only. Before each post-processor call,
    ``PostprocWorker._handle_input`` overwrites ``tokenizer`` with the
    worker's shared tokenizer.
    """

    # Whether the upcoming call is the first for this request; presumably
    # maintained by the post-processor itself (not modified in this file).
    first_iteration: bool = True

    # Number of prompt tokens, when the creator of the args knows it.
    num_prompt_tokens: Optional[int] = None

    # Tokenizer available to the post-processor; injected by the worker.
    tokenizer: Optional[TransformersTokenizer] = None

    # Opaque context-phase usage payload; its structure is not defined here.
    ctx_usage: Optional[Any] = None
@dataclass(kw_only=True)
class PostprocParams:
post_processor: Callable[["GenerationResultBase", PostprocArgs], Any] = None
postproc_args: PostprocArgs = None
@dataclass
class PostprocWorkerConfig:
    """Configuration for the post-processing worker pool.

    The pool is considered active only when ``num_postprocess_workers`` is
    positive; see the ``enabled`` property.
    """

    # How many post-processing workers to run; 0 disables post-processing.
    num_postprocess_workers: int = 0

    # Tokenizer directory for the workers (presumably forwarded as each
    # worker's ``tokenizer_dir``; the wiring is not visible here).
    postprocess_tokenizer_dir: Optional[str] = None

    # Dotted import path of the user post-processing hook, or None.
    # NOTE: not the same thing as ``PostprocParams.post_processor``, which is
    # the per-endpoint response *formatter* (a Callable), not this hook.
    post_processor_hook: Optional[str] = None

    @property
    def enabled(self) -> bool:
        """Whether a positive number of post-processing workers is set."""
        worker_count = self.num_postprocess_workers
        return worker_count > 0
class PostprocWorker:
    '''
    The worker to postprocess the responses from the executor's await_response.

    Inputs are pulled from the pull pipe (a single Input or a list of them per
    message). Each request gets one record, created by ``record_creator`` from
    its first Input, which accumulates that request's responses. Outputs are
    pushed back as one list per pulled message, and a ``None`` input is
    forwarded downstream as the shutdown signal.
    '''

    @dataclass
    class Input:
        # The response to post-process.
        rsp: Union["tllm.Response", "ResponseWrapper"]

        # The information necessary for creating a GenerationResult in the
        # first Input for each request.
        sampling_params: Optional[SamplingParams] = None
        postproc_params: Optional[PostprocParams] = None
        disaggregated_params: Optional["DisaggregatedParams"] = None
        streaming: Optional[bool] = None

    class Output(NamedTuple):
        # One post-processed response pushed back through the push pipe.
        client_id: int
        # Result of the request's post_processor, or its first output when
        # no PostprocParams is attached.
        res: Any
        # True once the request is finished (engine final, record forced
        # done, or a non-LLM response).
        is_final: bool
        metrics: Optional[dict[str, float]] = None
        request_perf_metrics: Any = None
        disaggregated_params: Any = None
        should_abort: bool = False
        finish_reason: Optional[str] = None
        num_generated_tokens: Optional[int] = None

    def __init__(
        self,
        pull_pipe_addr: tuple[str, Optional[bytes]],
        push_pipe_addr: tuple[str, Optional[bytes]],
        tokenizer_dir: str,
        record_creator: Callable[
            ["PostprocWorker.Input", TransformersTokenizer], Any],
        post_processor_hook: Optional[str] = None,
    ):
        '''
        Args:
            pull_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the input IPC.
            push_pipe_addr (tuple[str, Optional[bytes]]): The address and HMAC key of the output IPC.
            tokenizer_dir (str): The directory to load the tokenizer from.
            record_creator (Callable[["PostprocWorker.Input", TransformersTokenizer], Any]): Creates the record for a request from its first Input and the shared tokenizer.
            post_processor_hook (Optional[str]): Import path of the user post-processing hook; built once and threaded onto each record.
        '''
        self._records: Dict[int, GenerationResult] = {}
        self._record_creator = record_creator
        self._pull_pipe = ZeroMqQueue(address=pull_pipe_addr,
                                      is_async=True,
                                      is_server=False,
                                      name="postprocess_pull_pipe")
        self._push_pipe = ZeroMqQueue(address=push_pipe_addr,
                                      is_async=True,
                                      is_server=False,
                                      socket_type=zmq.PUSH,
                                      name="postprocess_push_pipe")
        self._to_stop = asyncio.Event()
        self._q = deque()

        # Load the tokenizer and share in all records
        self._tokenizer = load_hf_tokenizer(tokenizer_dir)

        # Build the user post-processing hook once, like the
        # tokenizer above; threaded onto each record in ``_handle_input``.
        self._post_processor_hook = (
            load_post_processor_hook(post_processor_hook)
            if post_processor_hook else None)

    @staticmethod
    def default_record_creator(
        inp: "PostprocWorker.Input", tokenizer: TransformersTokenizer
    ) -> "DetokenizedGenerationResultBase":
        ''' Create a DetokenizedGenerationResultBase from a request's first Input. '''
        from .result import DetokenizedGenerationResultBase
        assert inp.sampling_params is not None
        return DetokenizedGenerationResultBase(
            inp.rsp.client_id,
            sampling_params=inp.sampling_params,
            postproc_params=inp.postproc_params,
            streaming=inp.streaming,
            tokenizer=tokenizer)

    async def _handle_input(
        self, input: Union["PostprocWorker.Input", "ResponseWrapper"]
    ) -> tuple[Any, Optional[dict[str, float]], Any, Any]:
        ''' Handle a single response from the await_response worker.

        Returns:
            ``(output, metrics_dict, request_perf_metrics,
            disaggregated_params)`` for the request's record.

        Raises:
            ValueError: If the response carries context or generation logits.
        '''
        if input.rsp.result.context_logits is not None or \
                input.rsp.result.generation_logits is not None:
            raise ValueError(
                "Context logits or generation logits are not supposed to be "
                "sent to postprocessing workers.")

        with nvtx_range_debug("handle_input",
                              color="yellow",
                              category="Postproc"):
            req_id = input.rsp.client_id
            if req_id not in self._records:
                # TODO: support variant creation later
                self._records[req_id] = self._record_creator(
                    input, self._tokenizer)
                # Thread the hook onto the record here rather than
                # via record_creator, so custom record_creators keep working.
                self._records[
                    req_id]._post_processor_hook = self._post_processor_hook
                if input.disaggregated_params is not None:
                    self._records[
                        req_id]._disaggregated_params = input.disaggregated_params

            record = self._records[req_id]
            record._handle_response(input.rsp)  # inplace
            # Let the result_handler determine the final output dtype.
            # NOTE: This will change the CompletionOutput._postprocess_result
            metrics_dict = record.metrics_dict
            perf_metrics = None
            disaggregated_params = None
            if record.outputs:
                perf_metrics = record.outputs[0].request_perf_metrics
                disaggregated_params = record.outputs[0].disaggregated_params
            if postproc_params := record.postproc_params:
                result_handler, args = postproc_params.post_processor, postproc_params.postproc_args
                args.tokenizer = self._tokenizer
                out = result_handler(record, args)
            else:
                # This should only be called in streaming mode, and each time it
                # produces a single output.
                out = record.outputs[0]

            # TODO: Keep only the diff token_ids and text in streaming mode when
            # result_handler is not set
            return out, metrics_dict, perf_metrics, disaggregated_params

    async def _batched_put(self):
        ''' Batched IPC send; forwards ``None`` downstream and then stops. '''
        async for batch in self._mainloop():
            if batch is None:
                # notify dispatch_result coroutine to quit
                await self._push_pipe.put_async(None)
                break
            assert isinstance(batch, list)
            await self._push_pipe.put_async(batch)

    async def _mainloop(self):
        ''' The loop for handle_response and keep producing outputs.

        Yields one list of outputs per message pulled from the input pipe.
        On the ``None`` shutdown sentinel, outputs already produced from the
        same message are yielded first, then ``None`` is yielded and the
        generator ends.
        '''

        async def handle_single_input(inp: PostprocWorker.Input,
                                      batch: List[PostprocWorker.Output]):
            assert isinstance(
                inp, PostprocWorker.Input
            ), f"Expect PostprocWorker.Input, got {type(inp)}."
            client_id = inp.rsp.client_id

            # ErrorResponse has no 'result' attribute; pass it through
            # directly so the proxy handles it via its ErrorResponse path.
            if isinstance(inp.rsp, ErrorResponse):
                batch.append(inp.rsp)
                self._records.pop(client_id, None)
                return

            try:
                is_final = inp.rsp.result.is_final if is_llm_response(
                    inp.rsp) else True
                res, metrics, perf_metrics, disaggregated_params = await self._handle_input(
                    inp)
                record = self._records.get(client_id)
                # A `terminate` verdict forces the record done;
                # honor it so the stream stops and the record is popped without
                # waiting for the engine's own is_final.
                if record is not None and record._done:
                    is_final = True
                has_outputs = record is not None and bool(record.outputs)
                should_abort = record._aborted if record is not None else False
                # finish_reason is passed through for _handle_response.
                finish_reason = (record.outputs[0].finish_reason
                                 if has_outputs else None)
                num_generated_tokens = (len(record.outputs[0].token_ids)
                                        if has_outputs else None)
                batch.append(
                    PostprocWorker.Output(
                        client_id=client_id,
                        res=res,
                        is_final=is_final,
                        metrics=metrics,
                        request_perf_metrics=perf_metrics,
                        disaggregated_params=disaggregated_params,
                        should_abort=should_abort,
                        finish_reason=finish_reason,
                        num_generated_tokens=num_generated_tokens,
                    ))
                if is_final:
                    self._records.pop(client_id, None)
            except Exception as e:
                logger.error(
                    f"Postprocessing error for client {client_id}: {e}\n"
                    f"{traceback.format_exc()}")
                batch.append(
                    ErrorResponse(
                        client_id=client_id,
                        error_msg=f"Postprocessing error: {e}",
                        request_id=getattr(inp.rsp, 'request_id', -1),
                    ))
                self._records.pop(client_id, None)

        while not self._to_stop.is_set():
            batch = []
            inputs: Optional[Union[List[PostprocWorker.Input],
                                   PostprocWorker.Input]] = (
                                       await self._pull_pipe.get_async())
            if not isinstance(inputs, list):
                inputs = [inputs]

            stop_requested = False
            for inp in inputs:
                if inp is None:
                    stop_requested = True
                    break
                await handle_single_input(inp, batch)

            if stop_requested:
                self._to_stop.set()
                # Flush outputs produced before the sentinel in this message;
                # the consumer stops at the first None, so yielding None first
                # would silently drop them.
                if batch:
                    yield batch
                yield None
                return

            yield batch

    def start(self):
        ''' Start the workflow in the current thread (blocks until shutdown). '''

        async def main():
            await asyncio.gather(self._batched_put())

        try:
            asyncio.run(main())
        except Exception:
            print(traceback.format_exc())
            # Bare raise keeps the original traceback intact.
            raise
@print_traceback_on_error
def postproc_worker_main(feedin_ipc_addr: tuple[str, Optional[bytes]],
                         feedout_ipc_addr: tuple[str, Optional[bytes]],
                         tokenizer_dir: str,
                         record_creator: Callable,
                         post_processor_hook: Optional[str] = None):
    """Entry point that builds a PostprocWorker and runs it until shutdown.

    Args:
        feedin_ipc_addr: Address and HMAC key of the pipe the worker pulls from.
        feedout_ipc_addr: Address and HMAC key of the pipe the worker pushes to.
        tokenizer_dir: Directory the worker loads its tokenizer from.
        record_creator: Factory for per-request result records.
        post_processor_hook: Import path of the user hook; the worker loads
            it once itself, so only the path is passed here.
    """
    PostprocWorker(
        pull_pipe_addr=feedin_ipc_addr,
        push_pipe_addr=feedout_ipc_addr,
        tokenizer_dir=tokenizer_dir,
        record_creator=record_creator,
        post_processor_hook=post_processor_hook,
    ).start()