-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathbase_worker.py
More file actions
1278 lines (1105 loc) · 58 KB
/
Copy pathbase_worker.py
File metadata and controls
1278 lines (1105 loc) · 58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import enum
import gc
import json
import os
import weakref
from pathlib import Path
from queue import Queue
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import psutil
import torch
from tensorrt_llm.logger import logger
from .._torch.pyexecutor.kv_cache_stats import append_kv_cache_iteration_stats
from .._torch.pyexecutor.llm_request import LlmResponse
from .._utils import (global_mpi_rank, global_mpi_size, mpi_comm, mpi_rank,
nvtx_range_debug)
from ..bindings import executor as tllm
from ..builder import ConfigEncoder, Engine, EngineConfig
from ..llmapi.llm_args import BaseLlmArgs, ExecutorMemoryType, PybindMirror
from ..llmapi.tokenizer import TokenizerBase
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import _SyncQueue, get_numa_aware_cpu_affinity, logger_debug
from ..lora_manager import LoraManager
from ..prompt_adapter_manager import PromptAdapterManager
from ..runtime import ModelConfig
from ..runtime.model_runner import _engine_config_to_model_config
from ..sampling_params import BatchedLogitsProcessor, SamplingParams
from .executor import GenerationExecutor, IterationResultQueue
from .ipc import FusedIpcQueue, IpcQueue
from .postproc_worker import (PostprocParams, PostprocWorker,
PostprocWorkerConfig)
from .request import GenerationRequest, LoRARequest, PromptAdapterRequest
from .result import (GenerationResult, LogProbsResult, ResponseWrapper,
compute_logprobs, get_metrics_dict)
from .utils import (ErrorResponse, IntraProcessQueue, RequestError,
is_llm_response)
if TYPE_CHECKING:
from ..disaggregated_params import DisaggregatedParams
# Names exported by a star-import of this module.
__all__ = [
    "BaseWorker",
    "_init_hf_modules",
]
def _init_hf_modules():
    """Prepare the cached HuggingFace dynamic modules.

    Required for models loaded with trust_remote_code=True. Calling this
    more than once is safe (idempotent); it is meant to be invoked:
      1. When this module is imported (main process and spawned subprocesses)
      2. On entry to worker_main (forked processes or external MPI ranks)

    Failures never propagate: an ImportError is logged as a warning and any
    other exception is logged as an error.

    References: https://github.com/vllm-project/vllm/pull/871
    """
    try:
        from transformers.dynamic_module_utils import \
            init_hf_modules as _hf_init
        _hf_init()
        logger.debug("HF modules initialized")
    except ImportError as import_err:
        logger.warning(f"ImportError initializing HF modules: {import_err}")
    except Exception as err:
        logger.error(f"Exception initializing HF modules: {err}")
# Run once at import time so the main process and spawned subprocesses have
# the HF modules initialized (case 1 in the _init_hf_modules docstring).
_init_hf_modules()
class BaseWorker(GenerationExecutor):
# Exception type nested in BaseWorker; it derives from GeneratorExit.
# NOTE(review): presumably raised to signal worker shutdown - confirm at the
# raise sites, which are not visible in this part of the file.
class WorkerExit(GeneratorExit):
    pass
def __init__(
    self,
    engine: Union[Path, Engine],
    executor_config: Optional[tllm.ExecutorConfig] = None,
    batched_logits_processor: Optional[BatchedLogitsProcessor] = None,
    postproc_worker_config: Optional[PostprocWorkerConfig] = None,
    is_llm_executor: Optional[bool] = None,
    hf_model_dir: Optional[Path] = None,
    tokenizer: Optional[TokenizerBase] = None,
    llm_args: Optional[BaseLlmArgs] = None,
) -> None:
    """Record the worker inputs and initialise bookkeeping state.

    No engine is created here; ``self.engine`` stays ``None`` until
    ``setup_engine`` is called.

    Args:
        engine: Engine object or engine path stored as ``self._engine``.
        executor_config: Optional executor configuration, stored as-is.
        batched_logits_processor: Optional batched logits processor.
        postproc_worker_config: Post-processing settings; a default
            ``PostprocWorkerConfig()`` is used for the base-class
            initialisation when this is falsy.
        is_llm_executor: Forwarded to the base class and stored.
        hf_model_dir: Optional HuggingFace model directory.
        tokenizer: Optional tokenizer.
        llm_args: Optional LLM arguments; when given, its ``backend`` field
            selects the backend.
    """
    effective_postproc = postproc_worker_config
    if not effective_postproc:
        effective_postproc = PostprocWorkerConfig()
    super().__init__(
        num_postprocess_workers=effective_postproc.num_postprocess_workers,
        postprocess_tokenizer_dir=effective_postproc.
        postprocess_tokenizer_dir,
        is_llm_executor=is_llm_executor,
    )

    # Constructor inputs, kept verbatim (the original, possibly-None
    # postproc_worker_config is stored, not the defaulted one).
    self._engine = engine
    self._executor_config = executor_config
    self._batched_logits_processor = batched_logits_processor
    self._postproc_worker_config = postproc_worker_config
    self._is_llm_executor = is_llm_executor
    self._hf_model_dir = hf_model_dir
    self._tokenizer = tokenizer
    self.llm_args = llm_args

    # Runtime handles, filled in later.
    self.engine = None
    self.result_queue: Optional[IpcQueue] = None
    self.postproc_queues: Optional[List[IpcQueue]] = None

    self.rank = mpi_rank()
    self.global_rank = global_mpi_rank()

    # client_id -> GenerationResult
    self._results: Dict[int, GenerationResult] = {}
    # client_id (from Proxy) -> request_id returned by the runtime backend
    self._client_id_to_request_id: Dict[int, int] = {}
    self._await_response_helper = AwaitResponseHelper(weakref.proxy(self))

    self._backend = llm_args.backend if llm_args is not None else None
    self._is_pytorch_backend = self._backend in ("pytorch", "_autodeploy")
    if self._is_pytorch_backend:
        self._lora_config = llm_args.lora_config
    else:
        self._lora_config = None
    self._resource_governor_queue = None

    if global_mpi_size() > 1:
        logger.set_rank(self.global_rank)
@property
def resource_governor_queue(self):
    """Return the stored ``_resource_governor_queue`` (``None`` after ``__init__``)."""
    return self._resource_governor_queue
def _configure_affinity(self, device_id):
    '''Inspect and, where appropriate, set this worker's CPU affinity.

    Args:
        device_id: CUDA device ID passed to get_numa_aware_cpu_affinity to
            pick the CPU set.

    Note:
        A warning is logged whenever the process starts with an affinity
        that differs from the full list of logical CPUs.

        Behaviour depends on TLLM_NUMA_AWARE_WORKER_AFFINITY:

        <unset>
            -> Unconstrained affinity is auto-configured; an externally
               constrained affinity is reset to all logical CPUs.
        1
            -> Affinity is always auto-configured.
        0 or any other value
            -> Affinity is never auto-configured.
    '''
    pid = os.getpid()
    proc = psutil.Process(pid)
    current_cpus = proc.cpu_affinity()
    every_cpu = list(range(psutil.cpu_count()))
    is_constrained = current_cpus != every_cpu
    env_setting = os.environ.get("TLLM_NUMA_AWARE_WORKER_AFFINITY")

    if is_constrained:
        logger.warning(
            f"Worker process {pid} is affined to run on the following CPUs: "
            f"{current_cpus} (subset of all logical CPUs). This may harm "
            f"performance if set incorrectly.")
        # The user did not opt in or out explicitly: drop the constraint.
        if env_setting is None:
            logger.warning(
                f"Worker process {pid} has constrained CPU affinity "
                f"but `TLLM_NUMA_AWARE_WORKER_AFFINITY` is not set. "
                f"Removing CPU affinity constraints.")
            proc.cpu_affinity(every_cpu)

    # Auto-configure when explicitly requested, or when nothing was
    # requested and the process started out unconstrained.
    auto_configure = (env_setting == "1"
                      or (env_setting is None and not is_constrained))
    if not auto_configure:
        return

    proc.cpu_affinity(get_numa_aware_cpu_affinity(device_id))
    logger.info(f"Worker process {pid} CPU affinity set to "
                f"{proc.cpu_affinity()} for optimal NUMA-aware scheduling.")
def _get_comm_ranks_device_id(self):
    """Select this rank's CUDA device and gather ranks/devices over MPI.

    The device is ``global_rank % torch.cuda.device_count()``. After the
    device is made current, the global rank and the device id are each
    all-gathered over ``mpi_comm()``, and CPU affinity is configured for
    the chosen device.

    Returns:
        Tuple ``(comm_ranks, device_ids)`` holding the all-gathered global
        ranks and device ids.
    """
    local_device = self.global_rank % torch.cuda.device_count()
    torch.cuda.set_device(local_device)

    # Make sure C++ executor would use same devices/ranks as py_executor
    my_global_rank = global_mpi_rank()
    gathered_ranks = mpi_comm().allgather(my_global_rank)
    gathered_devices = mpi_comm().allgather(local_device)

    self._configure_affinity(local_device)
    return gathered_ranks, gathered_devices
def setup_engine(self):
    """Create the backend engine for this worker and its adapter managers.

    If ``self._engine`` is a list, the entry for this worker's rank is
    selected first. Then:

    * When ``self.llm_args`` is set, a Python executor is created for the
      ``"pytorch"`` or ``"_autodeploy"`` backend. This path also sets
      ``self.mapping``, ``self.checkpoint_loader`` and ``self.max_seq_len``.
    * Otherwise a ``tllm.Executor`` is built from ``self._engine`` (an
      in-memory ``Engine`` or an engine directory path).

    Afterwards ``self._lora_manager``, ``self._prompt_adapter_manager`` and
    ``self._runtime_model_config`` are initialised. On rank 0 with a
    ``tllm.Executor`` they are derived from the engine config; on the
    ``"pytorch"`` backend with a LoRA config the LoRA manager comes from the
    executor's PEFT cache manager and ``self._lora_model_config`` is set.

    Raises:
        ValueError: If ``llm_args`` is set with an unsupported backend.
    """
    if isinstance(self._engine, list):
        self._engine = self._engine[self.rank]

    def _create_py_executor():
        """Build the Python executor selected by ``self._backend``."""
        args = {}
        assert hasattr(
            self.llm_args, "backend"
        ), "llm_args should be with backend in _create_py_executor"
        # Called for its side effects (device selection, affinity); the
        # gathered ranks/devices are not needed on this path.
        _ = self._get_comm_ranks_device_id()
        if self._backend == "pytorch":
            from tensorrt_llm._torch.pyexecutor.py_executor_creator import \
                create_py_executor
            create_executor = create_py_executor
            args["llm_args"] = self.llm_args
            args["checkpoint_dir"] = self._hf_model_dir
            args["tokenizer"] = self._tokenizer
        elif self._backend == "_autodeploy":
            from tensorrt_llm._torch.auto_deploy.llm_args import \
                LlmArgs as ADLlmArgs
            from tensorrt_llm._torch.auto_deploy.shim.ad_executor import \
                create_autodeploy_executor
            create_executor = create_autodeploy_executor
            assert isinstance(self.llm_args, ADLlmArgs)
            args["ad_config"] = self.llm_args
            args["tokenizer"] = self._tokenizer
        else:
            raise ValueError(f"Unsupported backend config: {self._backend}")

        if self._resource_governor_queue is not None:
            args["resource_governor_queue"] = self._resource_governor_queue

        # Define additional attributes that can be used later, such as in _deduce_max_tokens
        self.mapping = self.llm_args.parallel_config.to_mapping()
        self.checkpoint_loader = None
        if self._backend == "pytorch":
            from tensorrt_llm._torch.pyexecutor.model_loader import \
                _construct_checkpoint_loader
            self.checkpoint_loader = _construct_checkpoint_loader(
                self.llm_args.backend,
                self.llm_args.checkpoint_loader,
                self.llm_args.checkpoint_format,
                mx_config=self.llm_args.mx_config,
                mx_model_name=self.llm_args.model,
            )

        self.max_seq_len = self.llm_args.max_seq_len

        # create_py_executor may change some fields of llm_args
        _executor = create_executor(**args)
        if _executor.max_seq_len is not None:
            # max_seq_len might be updated by model engine as in create_py_executor
            self.max_seq_len = _executor.max_seq_len
        return _executor

    def _create_engine(executor_config):
        """Build a ``tllm.Executor`` from ``self._engine``."""
        engine = self._engine
        if executor_config is None:
            executor_config = tllm.ExecutorConfig(1)
        executor_config.logits_post_processor_config = tllm.LogitsPostProcessorConfig(
            processor_batched=self._batched_logits_processor,
            replicate=False)
        comm_ranks, device_ids = self._get_comm_ranks_device_id()
        executor_config.parallel_config = tllm.ParallelConfig(
            participant_ids=comm_ranks, device_ids=device_ids)

        if isinstance(engine, Engine):
            return tllm.Executor(engine.engine,
                                 json.dumps(engine.config.to_dict(),
                                            cls=ConfigEncoder),
                                 tllm.ModelType.DECODER_ONLY,
                                 executor_config=executor_config,
                                 managed_weights=engine.managed_weights)

        assert not hasattr(executor_config, "backend")
        return tllm.Executor(engine, tllm.ModelType.DECODER_ONLY,
                             executor_config)

    self.engine = _create_py_executor(
    ) if self.llm_args is not None else _create_engine(
        self._executor_config)

    self._lora_manager: Optional[LoraManager] = None
    self._prompt_adapter_manager: Optional[PromptAdapterManager] = None
    self._runtime_model_config: Optional[ModelConfig] = None
    if self.rank == 0 and isinstance(self.engine, tllm.Executor):
        # `self.engine` is the tllm.Executor in this branch, so the engine
        # description has to come from `self._engine`: either an in-memory
        # Engine (use its config directly) or an engine directory path.
        if isinstance(self._engine, Engine):
            engine_config = self._engine.config
        else:
            engine_config = EngineConfig.from_json_file(
                f"{self._engine}/config.json")
        self._runtime_model_config = _engine_config_to_model_config(
            engine_config)
        if engine_config.build_config.plugin_config.lora_plugin:
            # TODO(azuker): Passing peft cache manager to LoraManager is used for LoRA optimization
            # (see LoraManager constructor docstring). Getting the peft cache manager from this
            # point in the TRT flow is currently not supported (it's at the CPP
            # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
            # optimization is not available in TRT-python flow.
            self._lora_manager = LoraManager(
                mapping=engine_config.pretrained_config.mapping,
                model_config=self._runtime_model_config,
                cpp_peft_cache_manager=None)
        if engine_config.build_config.max_prompt_embedding_table_size > 0:
            self._prompt_adapter_manager = PromptAdapterManager()

    if self._backend == "pytorch" and self._lora_config is not None:
        from tensorrt_llm._torch.pyexecutor.resource_manager import \
            ResourceManagerType
        peft_cache_manager = self.engine.resource_manager.resource_managers.get(
            ResourceManagerType.PEFT_CACHE_MANAGER)
        self._lora_manager = peft_cache_manager.get_lora_manager()
        lora_model_config = self.engine.model_engine.lora_model_config
        assert lora_model_config is not None
        self._lora_model_config = lora_model_config
def await_responses(self, timeout: Optional[float] = None) -> list:
    """Wait for responses from the engine.

    Args:
        timeout: Maximum wait in seconds, converted to a
            ``datetime.timedelta``; ``None`` is passed through unchanged.

    Returns:
        Whatever ``self.engine.await_responses`` returns.
    """
    wait_for = None
    if timeout is not None:
        wait_for = datetime.timedelta(seconds=timeout)
    return self.engine.await_responses(timeout=wait_for)
def fetch_stats(self) -> list:
    """Return the engine's latest iteration stats.

    For a ``tllm.Executor`` engine each iteration stat is wrapped as
    ``(iter_stat, None, None)``; for any other engine the engine's result
    is returned unchanged.
    """
    if not isinstance(self.engine, tllm.Executor):
        return self.engine.get_latest_iteration_stats()

    #TODO: Support req stats with TRT engine
    # This would require ensuring iter and req stats have same size
    latest = self.engine.get_latest_iteration_stats()
    return [(stat, None, None) for stat in latest]
def fetch_kv_cache_events(self) -> list:
    """Return the engine's latest KV-cache events.

    The same engine call is made for every engine type, so no dispatch on
    the engine class is needed.
    """
    return self.engine.get_latest_kv_cache_events()
def set_result_queue(self, queue):
    """In multi-gpu mode, result_queue will be set here to communicate between the proxy and the worker 0 process."""
    # The result queue and the postproc queues are mutually exclusive.
    assert self.postproc_queues is None
    self.result_queue = queue
def set_postproc_queues(self, queues: List["IpcQueue"]):
    """ Set the IPC queues for feeding post-processing processes. """
    # The postproc queues and the result queue are mutually exclusive.
    assert self.result_queue is None
    self.postproc_queues = queues
def _set_iteration_result_queue(self, it_result_queue: IterationResultQueue,
                                queue: Union[Queue, FusedIpcQueue,
                                             IntraProcessQueue]):
    """Attach ``queue`` to an iteration result queue that is not yet initialized.

    Marks ``it_result_queue`` as initialized, stores ``queue`` on it and
    resets its ``aqueue`` to ``None``. Asserts that it was not already
    initialized.
    """
    assert not it_result_queue.is_initialized, "Iteration result queue should not already be initialized."
    it_result_queue.is_initialized = True
    it_result_queue.queue = queue
    it_result_queue.aqueue = None
def return_queue(self, client_id: int):
    """Pick the queue on which responses for ``client_id`` are delivered.

    A registered centralized result queue (the one used to talk to the
    proxy) takes precedence. Without one, the queue owned by the
    GenerationResult stored for ``client_id`` is returned.
    """
    shared_queue = self.result_queue
    if shared_queue is None:
        return self._results[client_id].queue
    return shared_queue
def abort_request(self, client_id: int) -> None:
    """Cancel the backend request associated with ``client_id``.

    Does nothing when the engine reports that it cannot enqueue requests.
    If no backend request id is recorded for ``client_id``, a warning is
    logged and nothing is cancelled.
    """
    if not self.engine.can_enqueue_requests():
        return

    # NOTE: the request_id is the request_id generated by cpp runtime, not the client_id
    backend_request_id = self._client_id_to_request_id.get(client_id)
    if backend_request_id is None:
        logger.warning(
            f"Request of client_id {client_id} is finished, cannot abort it."
        )
        return

    self.engine.cancel_request(backend_request_id)
def _engine_response_callback(self, response: tllm.Response):
    """Return ``response`` unchanged.

    NOTE(review): looks like a hook for subclasses to override - confirm.
    """
    return response
def _has_background_error(self) -> bool:
    """Return True when ``self._error_queue`` is not empty."""
    # NOTE(review): `_error_queue` is not assigned in the visible part of
    # this class; presumably provided by the base class - confirm.
    return not self._error_queue.empty()
def _create_error_response(self, response: tllm.Response) -> ErrorResponse:
    """Turn the next queued background error into an ErrorResponse.

    Pops one item from ``self._error_queue`` without blocking, asserts that
    it is an Exception, and pairs its string form with the client and
    request ids of ``response``.
    """
    pending_error = self._error_queue.get_nowait()
    assert isinstance(pending_error, Exception)
    client_id = response.client_id
    message = str(pending_error)
    return ErrorResponse(client_id, message, response.request_id)
def start(self):
    """Not implemented in BaseWorker; always raises NotImplementedError.

    ``submit`` calls this method first, so it has to be overridden before
    requests can be submitted.
    """
    raise NotImplementedError(
        "start method is not implemented in BaseWorker")
def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
    """Load the requested LoRA adapter through the LoRA manager.

    The runtime model config is used when available, otherwise the LoRA
    model config.

    Returns:
        True if this call loaded the adapter, False if it was already
        loaded.
    """
    uid = str(lora_request.adapter_id)
    if self._runtime_model_config is not None:
        config_for_load = self._runtime_model_config
    else:
        config_for_load = self._lora_model_config
    loaded_now = self._lora_manager.load_from_ckpt(
        [lora_request.path],
        model_config=config_for_load,
        uids=[uid],
        ckpt_source=lora_request.ckpt_source)
    return uid in loaded_now
def _load_prompt_adapter(self,
                         prompt_adapter_request: PromptAdapterRequest):
    """Load a prompt adapter checkpoint via the prompt adapter manager.

    The checkpoint at ``prompt_adapter_request.local_path`` is loaded under
    the stringified adapter id, using the runtime model config.
    """
    adapter_uid = str(prompt_adapter_request.adapter_id)
    checkpoint_paths = [prompt_adapter_request.local_path]
    self._prompt_adapter_manager.load_from_ckpt(
        checkpoint_paths,
        model_config=self._runtime_model_config,
        uids=[adapter_uid])
def _enqueue_request(self,
                     request: GenerationRequest,
                     result_wait_queue=None) -> int:
    """Convert ``request`` into a ``tllm.Request`` and enqueue it on the engine.

    Args:
        request: The generation request; ``request.id`` must already be set.
        result_wait_queue: Optional queue that is forwarded to
            ``engine.enqueue_request`` only when the PyTorch/autodeploy
            backend is in use.

    Returns:
        The request id returned by ``self.engine.enqueue_request``.

    Raises:
        RequestError: If loading the LoRA adapter fails, or if building or
            enqueuing the executor request raises (this includes errors
            raised while deducing ``max_tokens``).
    """
    assert request.id is not None
    py_lora_path = None
    if self._lora_manager is not None and request.lora_request is not None:
        try:
            if self._is_pytorch_backend:
                # PyTorch backend: don't embed weights in the request.
                # Each rank loads independently from disk via py_lora_path
                # in PeftCacheManager.add_request_peft().
                # Pre-load on rank 0 to warm the LoRA manager cache so that
                # add_request_peft finds the adapter already loaded.
                self._load_lora_adapter(request.lora_request)
                uid = str(request.lora_request.adapter_id)
                lora_config = tllm.LoraConfig(
                    task_id=request.lora_request.adapter_id,
                    weights=None,
                    config=self._lora_manager.cpp_lora_config[uid])
            else:
                # Weights are attached only when the adapter was not in the
                # CPU cache before this request loaded it.
                adapter_in_cache = self._lora_manager.is_adapter_in_cpu_cache(
                    request.lora_request.adapter_id)
                self._load_lora_adapter(request.lora_request)
                uid = str(request.lora_request.adapter_id)
                lora_config = tllm.LoraConfig(
                    task_id=request.lora_request.adapter_id,
                    weights=self._lora_manager.cpp_lora_weights[uid]
                    if not adapter_in_cache else None,
                    config=self._lora_manager.cpp_lora_config[uid])
            py_lora_path = request.lora_request.lora_path
        except Exception as e:
            raise RequestError(f"Failed to load LoRA adapter: {e}") from e
    else:
        lora_config = None

    # Work on a copy so the request's own token list is not modified.
    prompt_token_ids = list(request.prompt_token_ids)
    prompt_tuning_config = None
    if request.prompt_adapter_request is not None:
        self._load_prompt_adapter(request.prompt_adapter_request)
        uid = str(request.prompt_adapter_request.adapter_id)
        prompt_tuning_config = tllm.PromptTuningConfig(
            self._prompt_adapter_manager.uid_to_weights[uid])
        vocab_size = self._runtime_model_config.vocab_size
        pa_length = prompt_tuning_config.embedding_table.size(0)
        # Prepend one token id per embedding-table row, numbered upwards
        # from vocab_size.
        prompt_token_ids = list(range(
            vocab_size, vocab_size + pa_length)) + prompt_token_ids

    # MULTIMODAL
    # NOTE: Since, we only support PyTorch backend for multimodal, we will send multimodal_data through the 'py_multimodal_data' field
    # except `multimodal_input` as it needs to go through the C++ runtime.
    multimodal_input = None
    if request.multimodal_params is not None and request.multimodal_params.has_content(
    ):
        if request.multimodal_params.multimodal_input is not None:
            multimodal_input = request.multimodal_params.multimodal_input.to_binding(
                tllm)
            # NOTE: Setting to None here to avoid sending multimodal_input again through the 'py_multimodal_data' field
            request.multimodal_params.multimodal_input = None

    # Defaults for a regular (non-disaggregated) request.
    context_phase_params = None
    request_type = tllm.RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION
    disagg_request_id = 0
    if request.disaggregated_params is not None:
        assert (
            not self._is_pytorch_backend
            or self.engine.kv_cache_transceiver is not None
            or request.disaggregated_params.request_type
            == "context_and_generation"
        ), "kv_cache_transceiver is disabled, please set 'cache_transceiver_config: backend:<backend_type>` in config file for disaggregated serving"
        request_type = request.disaggregated_params.get_request_type()
        disagg_request_id = request.disaggregated_params.disagg_request_id
        if request_type == tllm.RequestType.REQUEST_TYPE_GENERATION_ONLY:
            context_phase_params = request.disaggregated_params.get_context_phase_params(
            )

    assert request.id is not None

    def _deduce_max_tokens(request: GenerationRequest,
                           executor_config: tllm.ExecutorConfig,
                           llm_args: Optional[BaseLlmArgs] = None) -> int:
        """Return the ``max_tokens`` value to use for ``request``.

        The upper bound is ``max_seq_len`` minus the prompt length (divided
        by ``cp_size``) minus the query length. A user value above that
        bound is clamped to it with a warning; a missing user value
        defaults to it.

        Raises:
            ValueError: If no bound can be deduced and ``max_tokens`` is
                unset, if the bound is not positive, or if the user's
                ``max_tokens`` is not positive.
        """
        # deduce max_tokens when it's not set by user
        max_tokens = request.sampling_params.max_tokens
        query_token_len = len(
            request.query_token_ids) if request.query_token_ids else 0

        cp_size = 1
        max_seq_len = None
        if llm_args is not None:
            # deduce max_tokens by llm args
            assert executor_config is None, "An empty executor_config in _deduce_max_tokens is expected when LLM arguments are defined."
            if hasattr(self,
                       "mapping") and self.mapping.cp_size is not None:
                cp_size = self.mapping.cp_size
            max_seq_len = getattr(self, "max_seq_len", None)
        else:
            # deduce max_tokens by executor config
            if hasattr(executor_config, "mapping"
                       ) and executor_config.mapping.cp_size is not None:
                cp_size = executor_config.mapping.cp_size
            max_seq_len = getattr(executor_config, "max_seq_len", None)
        if max_seq_len is None:
            logger.warning("`default_max_tokens` cannot be deduced")
            if max_tokens is None:
                raise ValueError(
                    "`max_tokens` must be set when `default_max_tokens` cannot be deduced"
                )
            else:
                # use max_tokens if can't deduce default_max_tokens
                return max_tokens
        if executor_config is not None:
            assert (
                len(prompt_token_ids) <= executor_config.max_seq_len
            ), f"`prompt_token_ids` length ({len(prompt_token_ids)}) is greater than `max_seq_len` ({executor_config.max_seq_len})"
        # `prompt_token_ids` is taken from the enclosing function's scope.
        splited_prompt_len = int(len(prompt_token_ids) / cp_size)
        default_max_tokens = max_seq_len - splited_prompt_len - query_token_len
        if default_max_tokens <= 0:
            # Raise error on `default_max_tokens` not enough, since max_tokens should be less than `default_max_tokens``
            raise ValueError(
                f"`default_max_tokens` ({default_max_tokens}) must be greater than 0, "
                f"`default_max_tokens` ({default_max_tokens}) = max_seq_len ({max_seq_len})"
                f" - `splited_prompt_len` ({splited_prompt_len}) - `query_token_len` ({query_token_len})"
            )

        # default_max_tokens is the biggest available value
        if max_tokens is None:
            return default_max_tokens
        elif max_tokens > default_max_tokens and default_max_tokens > 0:
            logger.warning(
                f"User-specified `max_tokens` ({max_tokens}) is greater than deduced "
                f"`default_max_tokens` ({default_max_tokens}), using default_max_tokens instead."
            )
            return default_max_tokens
        elif max_tokens <= 0:
            raise ValueError(
                f"`max_tokens` ({max_tokens}) must be greater than 0")
        else:
            return max_tokens

    # Everything below is wrapped so that any failure surfaces as RequestError.
    try:
        executor_request = tllm.Request(
            client_id=request.id,
            input_token_ids=prompt_token_ids,
            max_tokens=_deduce_max_tokens(
                request,
                self._executor_config if not self.llm_args else None,
                self.llm_args),
            streaming=request.streaming,
            sampling_config=request.sampling_params._get_sampling_config(),
            end_id=-1 if request.sampling_params.ignore_eos else
            request.sampling_params.end_id,
            pad_id=request.sampling_params.pad_id,
            output_config=request.sampling_params._get_output_config(
                is_pytorch_backend=self._is_pytorch_backend),
            # Beam search enforces return_all_generated_tokens=True regardless of the passed value
            return_all_generated_tokens=False,
            # convert python config into pybind config
            lookahead_config=PybindMirror.maybe_to_pybind(
                request.sampling_params.lookahead_config),
            guided_decoding_params=request.sampling_params.
            _get_guided_decoding_params(),
            bad_words=request.sampling_params._get_bad_words(),
            stop_words=[] if request.sampling_params.ignore_eos else
            request.sampling_params._get_stop_words(),
            embedding_bias=request.sampling_params.embedding_bias,
            lora_config=lora_config,
            prompt_tuning_config=prompt_tuning_config,
            multimodal_input=multimodal_input,
            # NOTE: `multimodal_embedding` and `mrope_config` will be in MultimodalParams.multimodal_data. And this will be handled below by `py_multimodal_data`.
            multimodal_embedding=None,
            mrope_config=None,
            logits_post_processor_name=(
                tllm.Request.BATCHED_POST_PROCESSOR_NAME
                if request.sampling_params.apply_batched_logits_processor
                else None),
            logits_post_processor=None if self._is_pytorch_backend else
            request.sampling_params.logits_processor,
            kv_cache_retention_config=request.kv_cache_retention_config,
            context_phase_params=context_phase_params,
            encoder_input_token_ids=request.encoder_input_token_ids,
            type=request_type,
            disagg_request_id=disagg_request_id,
            cache_salt=request.cache_salt,
            priority=request.priority)
        # Extra `py_*` attributes attached to the request object in Python.
        executor_request.py_original_end_id = request.sampling_params.end_id
        executor_request.py_num_logprobs = request.sampling_params.logprobs
        executor_request.py_lora_path = py_lora_path
        executor_request.py_logprobs_mode = request.sampling_params.logprobs_mode
        executor_request.py_logprobs_simple_format = (
            request.sampling_params.logprobs_simple_format)

        # here we add executor_request.py_disaggregated_params= request.disaggregated_params for python cache transceiver
        if self._is_pytorch_backend and request.disaggregated_params is not None:
            executor_request.py_disaggregated_params = request.disaggregated_params

        if self._is_pytorch_backend and request.multimodal_params is not None:
            if request.multimodal_params.multimodal_data is not None:
                # Resolve SharedTensorContainer dicts inside multimodal_data, including
                # E/P handoff embedding handles parked under "multimodal_embedding".
                request.multimodal_params.to_tensor("multimodal_data")
                executor_request.py_multimodal_data = request.multimodal_params.multimodal_data

        if self._is_pytorch_backend and request.sampling_params.logits_processor:
            # For PyTorch backend, we attach logits processors as a dynamic Python attribute
            # instead of using the C++ binding, since the latter will cause PyCapsule pickling issues.
            lp = request.sampling_params.logits_processor
            executor_request.py_logits_post_processors = lp if isinstance(
                lp, list) else [lp]

        executor_request.py_scheduling_params = None
        if self._is_pytorch_backend and request.scheduling_params is not None:
            executor_request.py_scheduling_params = request.scheduling_params

        if request.arrival_time is not None:
            executor_request.py_arrival_time = request.arrival_time

        if request.query_token_ids is not None:
            # pytorch star attention workflow
            # a workaround to avoid public interface update
            if self._is_pytorch_backend and result_wait_queue is not None:
                req_id = self.engine.enqueue_request(
                    executor_request,
                    request.query_token_ids,
                    result_wait_queue=result_wait_queue)
            else:
                req_id = self.engine.enqueue_request(
                    executor_request, request.query_token_ids)
        else:
            if self._is_pytorch_backend and result_wait_queue is not None:
                req_id = self.engine.enqueue_request(
                    executor_request, result_wait_queue=result_wait_queue)
            else:
                req_id = self.engine.enqueue_request(executor_request)
        return req_id
    except Exception as e:
        raise RequestError(str(e)) from e
def submit(self, request: GenerationRequest) -> GenerationResult:
    """Low-level entry point: enqueue ``request`` and return its future.

    The returned GenerationResult can be waited on. A client id is assigned
    to the request when it does not carry one yet.

    Raises:
        RuntimeError: When called on a rank other than 0.
    """
    self.start()

    if self.rank != 0:
        raise RuntimeError(
            "Only rank 0 can submit requests.\n"
            "To fix this, ensure that the llm.generate(...) method is "
            "guarded with the `if __name__ == '__main__':` block.")

    if request.id is None:
        client_id = self._get_next_client_id()
        request.set_id(client_id)
    else:
        client_id = request.id

    future = GenerationResult(
        request,
        background_error_handler=self._handle_background_error,
        executor=self,
        disaggregated_params=request.disaggregated_params,
        logprob_params=self._get_logprob_params(request))
    self._results[client_id] = future

    # request_id returned from backend is necessary for the abort_request method.
    self._client_id_to_request_id[client_id] = self._enqueue_request(
        request)

    self._handle_background_error()
    return future
def _check_sleep_wakeup_preconditions(self, method: str) -> None:
    """Check the requirements common to sleep() and wakeup().

    Args:
        method: Name of the caller (``"sleep"`` or ``"wakeup"``), used in
            the error messages.

    Raises:
        ValueError: If the backend is not ``"pytorch"`` or ``sleep_config``
            is not set.
        NotImplementedError: If ``parallel_config.world_size > 1``.
    """
    # _autodeploy is intentionally excluded: its allocations are not tagged
    # under sleep_config VMM scopes, so release_with_tag would silently
    # no-op instead of actually freeing GPU memory. Use _backend directly
    # rather than _is_pytorch_backend, which also covers _autodeploy.
    backend_is_torch = self._backend == "pytorch"
    if not backend_is_torch:
        raise ValueError(
            f"{method}() is only available for the PyTorch (TorchLLM) "
            "backend.")

    args = self.llm_args
    sleep_enabled = args is not None and args.sleep_config is not None
    if not sleep_enabled:
        raise ValueError(
            "Sleep feature is not enabled, please set sleep_config in "
            "the LLM arguments.")

    # Non-rank-0 processes block on their local control_action_done
    # threading.Event with no Python caller to release it — deadlock.
    world_size = self.llm_args.parallel_config.world_size
    if world_size > 1:
        raise NotImplementedError(
            f"{method}() requires parallel_config.world_size == 1; "
            "use the Ray executor for multi-rank deployments.")
def sleep(self, sleep_tags: List[str]) -> None:
    """Release GPU virtual memory for the specified memory type tags.

    Single-rank (``world_size == 1``) only. Uses
    ``PyExecutor.control_action()`` to drain in-flight requests and pause
    the event loop before calling ``release_with_tag()``, matching the
    ``@control_action_decorator`` behaviour used in Ray.

    Only allocations backed by virtual memory (VMM) and registered under
    the active :class:`~tensorrt_llm.llmapi.llm_args.SleepConfig` are
    released. Components using alternative memory management (e.g.
    ``LoadFormat.GMS``-managed weights) are not VMM-tagged and will be
    silently skipped by ``release_with_tag``.

    Args:
        sleep_tags: List of
            :class:`~tensorrt_llm.llmapi.llm_args.ExecutorMemoryType`
            value strings (e.g. ``["kv_cache"]``).

    Returns:
        None. The call is synchronous; when it returns all requested
        VMM-tagged allocations have been released and the event loop
        has been resumed.

    Raises:
        ValueError: If the backend is not ``"pytorch"`` or
            ``sleep_config`` is not set. Presumably also raised by
            ``ExecutorMemoryType(tag)`` for an unrecognised tag string
            (confirm against the enum definition).
        NotImplementedError: If ``parallel_config.world_size > 1``.
    """
    # Validate backend / sleep_config / world size before touching the
    # engine or importing the VMM helpers.
    self._check_sleep_wakeup_preconditions("sleep")
    # Function-scope import: the helper is only needed on this path.
    from tensorrt_llm._torch.virtual_memory import release_with_tag
    # Convert every tag string up front, before the engine is paused.
    tags = [ExecutorMemoryType(tag) for tag in sleep_tags]
    logger.info(f"Sleep: {tags}")
    with self.engine.control_action():
        # Synchronize on both sides of the release; presumably so no
        # in-flight GPU work still touches the memory being unmapped
        # (TODO confirm against the virtual_memory module).
        torch.cuda.synchronize()
        release_with_tag(*tags)
        torch.cuda.synchronize()
        # Collect Python garbage first, then ask the CUDA caching
        # allocator to drop its cached blocks.
        gc.collect()
        torch.cuda.empty_cache()
def wakeup(self, wakeup_tags: List[str]) -> None:
    """Materialize GPU virtual memory for the specified memory type tags.

    Single-rank (``world_size == 1``) only. See :meth:`sleep` for
    details on VMM scope restrictions and backend prerequisites.

    Args:
        wakeup_tags: List of
            :class:`~tensorrt_llm.llmapi.llm_args.ExecutorMemoryType`
            value strings (e.g. ``["kv_cache"]``).

    Returns:
        None. The call is synchronous; when it returns all requested
        VMM-tagged allocations have been materialized and the event loop
        has been resumed.

    Raises:
        ValueError: If the backend is not ``"pytorch"`` or
            ``sleep_config`` is not set. Presumably also raised by
            ``ExecutorMemoryType(tag)`` for an unrecognised tag string
            (confirm against the enum definition).
        NotImplementedError: If ``parallel_config.world_size > 1``.
    """
    # Validate backend / sleep_config / world size before touching the
    # engine or importing the VMM helpers.
    self._check_sleep_wakeup_preconditions("wakeup")
    # Function-scope import: the helper is only needed on this path.
    from tensorrt_llm._torch.virtual_memory import materialize_with_tag
    # Convert every tag string up front, before the engine is paused.
    tags = [ExecutorMemoryType(tag) for tag in wakeup_tags]
    logger.info(f"Wakeup: {tags}")
    with self.engine.control_action():
        # Synchronize on both sides of the re-materialization so the call
        # returns only after pending GPU work has completed.
        torch.cuda.synchronize()
        materialize_with_tag(*tags)
        torch.cuda.synchronize()
def reset_prefix_cache(self) -> None:
    """Invalidate local KV prefix-cache reuse state on PyTorch engines.

    The engine's own ``reset_prefix_cache()`` is invoked inside its
    ``control_action()`` context.

    Raises:
        NotImplementedError: If there is no engine, or the engine does not
            expose a ``reset_prefix_cache`` attribute.
    """
    target = self.engine
    supported = target is not None and hasattr(target, "reset_prefix_cache")
    if not supported:
        raise NotImplementedError(
            "reset_prefix_cache() is only supported by the PyTorch backend.")

    with target.control_action():
        target.reset_prefix_cache()
def shutdown(self):
    """Shut the engine down once; repeated calls return immediately.

    ``doing_shutdown`` is set on the first call. The engine is shut down
    and the reference cleared only when an engine exists and reports
    ``can_enqueue_requests()``.
    """
    if self.doing_shutdown:
        return
    self.doing_shutdown = True

    if self.engine is None:
        return
    if not self.engine.can_enqueue_requests():
        return

    self.engine.shutdown()
    self.engine = None
def get_disaggregated_params(self) -> dict:
    """Return the KV cache transceiver's disaggregated params.

    Returns:
        Whatever ``engine.kv_cache_transceiver.get_disaggregated_params()``
        returns, or an empty dict (after logging a warning) when the engine
        or its transceiver is ``None``.
    """
    engine = self.engine
    transceiver = None if engine is None else engine.kv_cache_transceiver
    if transceiver is None:
        logger.warning("Engine or kv cache transceiver is not initialized")
        return {}
    return transceiver.get_disaggregated_params()
@staticmethod
def _stats_serializer(stats) -> str:
# Per-rank path: stats is ("per_rank_dict", {..., "rank": N}).
# Already serialized on the producing rank via allgather — just emit.
if (isinstance(stats, tuple) and len(stats) == 2
and stats[0] == "per_rank_dict"):
return json.dumps(stats[1])
iteration_stats, req_stats = stats[0], stats[1]
kv_iter_stats = stats[2] if len(stats) > 2 else None
attention_dp_rank = stats[3] if len(stats) > 3 else None
# Newer slots — guarded with len() checks so historical 4-tuples and
# any external code still appending the legacy shape keep working.
host_step_time_ms = stats[4] if len(stats) > 4 else None
prev_device_step_time_ms = stats[5] if len(stats) > 5 else None
scheduler_mode = stats[6] if len(stats) > 6 else None
gpu_forward_time_ms = stats[7] if len(stats) > 7 else None
stats_dict = json.loads(iteration_stats.to_json_str())
# Always tag the row so Dynamo's adapter can read
# stat["attentionDpRank"] without a missing-key branch. Non-ADP stats
# default to rank 0; ADP stats carry the rank supplied by PyExecutor.
stats_dict["attentionDpRank"] = (0 if attention_dp_rank is None else
attention_dp_rank)
if req_stats is not None and len(req_stats) > 0:
stats_dict["requestStats"] = []
for req_stat in req_stats:
stats_dict["requestStats"].append(
json.loads(req_stat.to_json_str()))
append_kv_cache_iteration_stats(stats_dict, kv_iter_stats)
# Per-loop CPU wall captured by profile_step() — always a clean
# single-loop measurement, matching the log line's `host_step_time`.
# Prefer this over iterLatencyMS when you need absolute per-loop
# CPU cost, especially under the overlap scheduler where
# iterLatencyMS measures the batch's full lifecycle (~2 loops).
if host_step_time_ms is not None:
stats_dict["hostStepTimeMS"] = host_step_time_ms
# GPU forward time read via the ping-pong CUDA event pair in
# profile_step(). Note the "prev" in the name: under steady state
# the value lags its sibling host_step_time on the same record by
# one loop (the event-pair being read corresponds to the loop
# before the one host_step_time describes). See the ping-pong
# comment in PyExecutor._profiler for the design rationale.
if prev_device_step_time_ms is not None:
stats_dict["prevDeviceStepTimeMS"] = prev_device_step_time_ms
# Batch-matched GPU forward time measured from the CUDA events around
# this record's _forward_step. This is the preferred field for
# ForwardPassMetrics wall_time.
if gpu_forward_time_ms is not None:
stats_dict["gpuForwardTimeMS"] = gpu_forward_time_ms
# Scheduler mode for this record. "overlap" means iterLatencyMS
# spans ~2 loops (use hostStepTimeMS for clean per-loop cost);
# "non_overlap" means iterLatencyMS is itself the clean per-loop
# CPU wall. Set per-record so consumers do not need server config.
if scheduler_mode is not None:
stats_dict["schedulerMode"] = scheduler_mode
# Convert back to JSON string
return json.dumps(stats_dict)
# Define a Callable to serialize KV cache events
@staticmethod
def _kv_cache_events_serializer(events) -> str:
    """Serialize KV cache events to a JSON string.

    ``events`` is passed through ``KVCacheEventSerializer.serialize`` and
    the result is dumped with ``json.dumps``.
    """
    from .._utils import KVCacheEventSerializer

    serialized = KVCacheEventSerializer.serialize(events)
    return json.dumps(serialized)
def _pop_result(self, client_id: int):
    """Forget everything tracked for ``client_id``.

    Removes the entry from both ``_results`` and
    ``_client_id_to_request_id``; an id missing from either is ignored.
    """
    for table in (self._results, self._client_id_to_request_id):
        table.pop(client_id, None)
def __enter__(self):
    """Enter the ``with`` block; the worker itself is the bound value."""
    return self
def __exit__(self, exc_type, exc_value, traceback) -> bool:
    """Shut down on leaving the ``with`` block.

    Returns:
        True when the block exited cleanly or with exactly
        ``self.WorkerExit`` (which is thereby suppressed); otherwise the
        result of the ``==`` comparison, so other exceptions propagate.
    """
    self.shutdown()
    if exc_type is None:
        return True
    return exc_type == self.WorkerExit
def __del__(self):
    """Run ``shutdown()`` when the object is finalized."""
    self.shutdown()
class AwaitResponseHelper:
''' Multiple-implementations for await_response for performance. '''
class HandlerKind(enum.Enum):
    """Which response-handling implementation the helper dispatches to."""
    # Initial state: not resolved yet. responses_handler() picks one of the
    # concrete kinds below the first time it runs.
    unknown = 0
    # The worker has neither a result_queue nor postproc_queues, i.e. it
    # runs in the main process (single process mode).
    single_process_worker = 1
    # The worker has a result_queue and/or postproc_queues, i.e. responses
    # are forwarded over IPC to a proxy.
    ipc_batched = 2
def __init__(self, worker: "BaseWorker"):
    """Bind the helper to ``worker`` and start with an unresolved handler.

    Args:
        worker: The worker whose engine responses this helper awaits and
            dispatches.
    """
    kinds = AwaitResponseHelper.HandlerKind

    # TODO: make worker weakref
    self.worker = worker
    # Resolved lazily to a concrete kind by responses_handler().
    self.handler_kind: AwaitResponseHelper.HandlerKind = kinds.unknown
    # NOTE(review): the attribute name's spelling ("postprocprocess") is
    # kept unchanged because code outside this block may read it.
    self.enable_postprocprocess_parallel = worker.enable_postprocess_parallel
    # The error responses when submit request failed will be put here
    self.temp_error_responses = Queue()
def responses_handler(self, responses: List[tllm.Response]):
HandlerKind = AwaitResponseHelper.HandlerKind
if self.handler_kind is HandlerKind.unknown:
if not (self.worker.result_queue is not None
or self.worker.postproc_queues is not None):
logger_debug(f"creating await_response helper for Worker\n",
color="yellow")
# When ExecutorBindingWorker is used in the main process
# aka the single process mode
self.handler_kind = HandlerKind.single_process_worker
elif self.worker.result_queue is not None or self.worker.postproc_queues is not None:
# The ExecutorBindingProxy is used
logger_debug(f"creating await_response helper for IPC\n",
color="yellow")
self.handler_kind = HandlerKind.ipc_batched
else:
raise NotImplementedError
match self.handler_kind:
case HandlerKind.single_process_worker:
return self.handle_for_worker(responses)
case HandlerKind.ipc_batched:
return self.handle_for_ipc_batched(responses)
case _:
raise NotImplementedError
def __call__(self, timeout: Optional[float] = None) -> bool:
    ''' Poll the engine once and dispatch whatever came back.

    This method should be called by a ManagedThread.

    Args:
        timeout: Seconds to wait for engine responses; a falsy value
            (``None`` or ``0``) falls back to 0.1.

    Returns:
        ``True`` after a normal poll; otherwise the return value of
        ``_broadcast_event_loop_error`` when ``await_responses`` raised or
        the engine carries an ``_event_loop_error``.
    '''
    wait_seconds = timeout or 0.1

    try:
        raw_responses = self.worker.engine.await_responses(
            timeout=datetime.timedelta(seconds=wait_seconds))
    except Exception as e:
        # Defensive: with id=None, PyExecutor.await_responses routes to
        # _await_any_response, which does not raise on event-loop crash;
        # it returns [] silently and the crash is detected through
        # engine._event_loop_error further down. Any unexpected exception
        # out of await_responses (e.g. from a different engine
        # implementation, or a future change to _await_any_response) is
        # still a clear signal to broadcast and stop the thread.
        return self._broadcast_event_loop_error(e)

    # The _engine_response_callback may return None, so keep only the
    # truthy results.
    processed = [
        self.worker._engine_response_callback(item) for item in raw_responses
    ]
    responses = [item for item in processed if item]

    # Append the queued error responses from temp_error_responses.
    while not self.temp_error_responses.empty():
        responses.append(self.temp_error_responses.get())

    with nvtx_range_debug(f"await_response-{len(responses)}",
                          color="red",
                          category="Worker"):
        self.responses_handler(responses)

    # Even when await_responses returned normally (e.g. via
    # _await_any_response, whose predicate already includes is_shutdown
    # but does not raise), an event-loop crash leaves _event_loop_error
    # stashed on the engine. Broadcast and stop the thread in that case
    # too (see nvbug 6038228).
    loop_error = getattr(self.worker.engine, "_event_loop_error", None)
    if loop_error is not None:
        return self._broadcast_event_loop_error(loop_error)
    return True
def _broadcast_event_loop_error(self, error: BaseException) -> bool:
"""Wake every pending ``GenerationResult`` after an event-loop crash.
Inject an ``ErrorResponse`` into every pending ``GenerationResult``
queue so callers parked in ``queue.get()`` / ``aqueue.get()``
(``LLM.generate``, ``generate_async`` + ``aresult``,
``trtllm-bench``) wake up with a meaningful error instead of
hanging when the PyExecutor event loop dies.
Returns ``False`` so the calling ``ManagedThread`` exits — there
is no point polling a dead engine.
Scope: single-process worker (the path that backs ``LLM.generate``
and the bench async client). The IPC / proxy path tracks pending
results on a different side of the boundary and would need a