Skip to content

Commit 00218e5

Browse files
authored
[None][feat] Add llm.encode() fast path for encoder-only models (#12801)
Signed-off-by: tingyangk <tingyangk@nvidia.com>
1 parent 12af895 commit 00218e5

13 files changed

Lines changed: 1015 additions & 12 deletions

File tree

tensorrt_llm/_torch/auto_deploy/llm.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,15 @@ class DemoLLM(LLM):
181181
def __init__(self, **kwargs):
182182
self.args: LlmArgs = LlmArgs(**kwargs)
183183

184+
if self.args.encode_only:
185+
raise NotImplementedError(
186+
"encode_only=True is not supported by DemoLLM (AutoDeploy debug path). "
187+
"Use the standard LLM class with backend='pytorch' for encoder-only mode."
188+
)
189+
# set encode_only and encoder_executor to BaseLLM's default values
190+
self._encode_only = False
191+
self._encoder_executor = None
192+
184193
self.mpi_session = None
185194
self.runtime_context = None
186195

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Copyright (c) 2025-2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Any, Dict
16+
17+
import torch
18+
19+
from tensorrt_llm.logger import logger
20+
21+
22+
class EncoderExecutor:
23+
"""Executor for models using the encode-only path.
24+
25+
Primary path: batch_forward(inputs) — synchronous batch execution.
26+
Delegates to model_engine.encoder_forward() for all heavy lifting
27+
(pre-allocated buffers, attention metadata, torch.compile).
28+
29+
This executor has no background thread, no scheduler, no sampler,
30+
and no request queue. It runs entirely on the calling thread.
31+
"""
32+
33+
def __init__(self, model_engine, dist):
34+
self.model_engine = model_engine
35+
self.dist = dist
36+
37+
logger.info(
38+
"encode_only path enabled: using EncoderExecutor. "
39+
"Scheduler, sampler, KV cache, and generation-related parameters "
40+
"(disable_overlap_scheduler, max_tokens, temperature, etc.) "
41+
"are bypassed. Use llm.encode() for inference."
42+
)
43+
44+
def batch_forward(self, inputs: Dict[str, Any]) -> Dict[str, torch.Tensor]:
45+
"""Execute a pre-formed batch in one forward pass.
46+
47+
Args:
48+
inputs: Dict with 'input_ids' ([total_tokens]) and 'seq_lens'
49+
([batch_size]) required. Optional model-specific kwargs
50+
(token_type_ids, inputs_embeds, etc.) are passed through.
51+
52+
Returns:
53+
Dict with 'logits' tensor and any other model outputs.
54+
"""
55+
return self.model_engine.encoder_forward(inputs)
56+
57+
def shutdown(self):
58+
"""No background thread to stop — just release model engine resources."""
59+
del self.model_engine

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 86 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3705,6 +3705,82 @@ def _prepare_inputs(
37053705
num_accepted_tokens_device, req_id_to_old_request, resource_manager,
37063706
maybe_graph)
37073707

3708+
def _prepare_encoder_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
3709+
"""Prepare model-ready inputs dict for encode-only path.
3710+
3711+
Encoder equivalent of _prepare_tp_inputs + _preprocess_inputs.
3712+
Consumes raw inputs dict, copies to pre-allocated CUDA buffers,
3713+
sets up attention metadata, and returns model-ready dict.
3714+
3715+
Args:
3716+
inputs: Dict with required keys 'input_ids' ([total_tokens]) and
3717+
'seq_lens' ([batch_size]). Optional 'position_ids'
3718+
([total_tokens]). Any additional keys (token_type_ids,
3719+
inputs_embeds, etc.) are passed through to the model's
3720+
forward() via **kwargs.
3721+
"""
3722+
token_ids = inputs['input_ids']
3723+
seq_lens = inputs['seq_lens']
3724+
position_ids = inputs.get('position_ids')
3725+
num_tokens = token_ids.shape[0]
3726+
batch_size = seq_lens.shape[0]
3727+
3728+
assert num_tokens <= self.max_num_tokens, (
3729+
f"num_tokens ({num_tokens}) exceeds max_num_tokens "
3730+
f"({self.max_num_tokens}). Reduce batch size or sequence lengths.")
3731+
3732+
# 1. Copy to pre-allocated CUDA buffers
3733+
self.input_ids_cuda[:num_tokens].copy_(token_ids, non_blocking=True)
3734+
if position_ids is None:
3735+
# Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...]
3736+
position_ids = torch.cat(
3737+
[torch.arange(s, dtype=torch.int32) for s in seq_lens.tolist()])
3738+
self.position_ids_cuda[:num_tokens].copy_(position_ids,
3739+
non_blocking=True)
3740+
3741+
# 2. Set up attention metadata
3742+
attn_metadata = self._set_up_attn_metadata(kv_cache_manager=None)
3743+
attn_metadata.seq_lens = seq_lens
3744+
attn_metadata.num_contexts = batch_size
3745+
attn_metadata.max_seq_len = self.max_seq_len
3746+
attn_metadata.request_ids = list(range(batch_size))
3747+
attn_metadata.prepare()
3748+
3749+
# 3. Build model-ready dict.
3750+
# **inputs goes FIRST so that the explicit buffer keys override the
3751+
# raw tensors. Extra keys pass through to the model's **kwargs
3752+
# are silently ignored if not in the model's forward() signature.
3753+
model_inputs = {
3754+
**inputs,
3755+
'attn_metadata': attn_metadata,
3756+
'input_ids': self.input_ids_cuda[:num_tokens],
3757+
'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
3758+
}
3759+
3760+
return model_inputs
3761+
3762+
@torch.inference_mode()
3763+
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
3764+
@nvtx_range("encoder_forward")
3765+
def encoder_forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
3766+
"""Direct tensor-level forward for encode-only path.
3767+
3768+
Bypasses ScheduledRequests/LlmRequest entirely. Takes a raw inputs
3769+
dict, prepares model-ready inputs via _prepare_encoder_inputs, and
3770+
calls _forward_step (which preserves torch.compile).
3771+
3772+
Args:
3773+
inputs: Dict with 'input_ids' and 'seq_lens' (required), plus
3774+
any model-specific kwargs (token_type_ids, inputs_embeds, etc.).
3775+
3776+
Returns:
3777+
Dict with 'logits' tensor and any other model outputs.
3778+
"""
3779+
model_inputs = self._prepare_encoder_inputs(inputs)
3780+
return self._forward_step(model_inputs,
3781+
gather_ids=None,
3782+
gather_context_logits=False)
3783+
37083784
@torch.inference_mode()
37093785
@with_model_extra_attrs(lambda self: self.model.extra_attrs)
37103786
def forward(self,
@@ -3790,8 +3866,10 @@ def forward(self,
37903866
return self._forward_step_mm_encoder_only(
37913867
inputs, scheduled_requests)
37923868
else:
3793-
return self._forward_step(inputs, gather_ids,
3794-
gather_context_logits)
3869+
return self._forward_step(
3870+
inputs,
3871+
gather_ids=gather_ids,
3872+
gather_context_logits=gather_context_logits)
37953873
with self.cuda_graph_runner.pad_batch(
37963874
scheduled_requests, resource_manager,
37973875
self.runtime_draft_len) as padded_requests:
@@ -3846,8 +3924,10 @@ def forward(self,
38463924
if not can_run_graph:
38473925
# Fallback to eager execution if graph was not used
38483926
with MoeLoadBalancerIterContext(moe_load_balancer):
3849-
outputs = self._forward_step(inputs, gather_ids,
3850-
gather_context_logits)
3927+
outputs = self._forward_step(
3928+
inputs,
3929+
gather_ids=gather_ids,
3930+
gather_context_logits=gather_context_logits)
38513931
else:
38523932
if self.cuda_graph_runner.needs_capture(key):
38533933

@@ -3919,7 +3999,8 @@ def model_forward(self, **kwargs):
39193999
@nvtx_range("_forward_step")
39204000
def _forward_step(self,
39214001
inputs: Dict[str, Any],
3922-
gather_ids: Optional[torch.Tensor],
4002+
*,
4003+
gather_ids: Optional[torch.Tensor] = None,
39234004
gather_context_logits: bool = False) -> Dict[str, Any]:
39244005
inputs = self._preprocess_inputs(inputs)
39254006
if inputs.get('spec_metadata', None):

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,57 @@ def get_guided_decoding_config(guided_decoding_backend: str,
225225
return guided_decoding_config
226226

227227

228+
def _load_config_and_create_checkpoint_loader(
229+
llm_args: TorchLlmArgs, checkpoint_dir: Optional[str] = None):
230+
torch.cuda.set_per_process_memory_fraction(1.0)
231+
checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,
232+
llm_args.checkpoint_loader,
233+
llm_args.checkpoint_format)
234+
llm_args = ModelLoader.load_config_and_apply_defaults(
235+
checkpoint_dir, llm_args, checkpoint_loader)
236+
return llm_args, checkpoint_loader
237+
238+
239+
def create_encoder_executor(
240+
llm_args: TorchLlmArgs,
241+
checkpoint_dir: Optional[str] = None,
242+
):
243+
"""Create an EncoderExecutor for models using the encode-only path.
244+
245+
Handles model loading and model_engine creation, then wraps in a
246+
lightweight EncoderExecutor. Skips all decoder infrastructure
247+
(KV cache, scheduler, sampler, drafter, speculative decoding).
248+
249+
Args:
250+
llm_args: Configuration arguments.
251+
checkpoint_dir: Path to model checkpoint.
252+
253+
Returns:
254+
An EncoderExecutor instance ready for batch_forward() calls.
255+
"""
256+
from .encoder_executor import EncoderExecutor
257+
258+
llm_args, checkpoint_loader = _load_config_and_create_checkpoint_loader(
259+
llm_args, checkpoint_dir)
260+
261+
mapping = _get_mapping(llm_args.parallel_config.to_mapping())
262+
dist = Distributed.get(mapping)
263+
264+
model_engine = PyTorchModelEngine(
265+
model_path=checkpoint_dir,
266+
llm_args=llm_args,
267+
mapping=mapping,
268+
dist=dist,
269+
spec_config=None,
270+
checkpoint_loader=checkpoint_loader,
271+
)
272+
273+
return EncoderExecutor(
274+
model_engine=model_engine,
275+
dist=dist,
276+
)
277+
278+
228279
def create_py_executor(
229280
llm_args: TorchLlmArgs,
230281
checkpoint_dir: Optional[str] = None,
@@ -250,13 +301,8 @@ def create_py_executor(
250301
"""
251302

252303
skip_est = os.environ.get("TRTLLM_SKIP_KV_CACHE_ESTIMATION", '0') == '1'
253-
torch.cuda.set_per_process_memory_fraction(1.0)
254-
# Apply model-specific defaults early, before destructuring llm_args fields
255-
checkpoint_loader = _construct_checkpoint_loader(llm_args.backend,
256-
llm_args.checkpoint_loader,
257-
llm_args.checkpoint_format)
258-
llm_args = ModelLoader.load_config_and_apply_defaults(
259-
checkpoint_dir, llm_args, checkpoint_loader)
304+
llm_args, checkpoint_loader = _load_config_and_create_checkpoint_loader(
305+
llm_args, checkpoint_dir)
260306

261307
garbage_collection_gen0_threshold = llm_args.garbage_collection_gen0_threshold
262308
lora_config = llm_args.lora_config

0 commit comments

Comments
 (0)