@@ -3705,6 +3705,82 @@ def _prepare_inputs(
37053705 num_accepted_tokens_device , req_id_to_old_request , resource_manager ,
37063706 maybe_graph )
37073707
3708+ def _prepare_encoder_inputs (self , inputs : Dict [str , Any ]) -> Dict [str , Any ]:
3709+ """Prepare model-ready inputs dict for encode-only path.
3710+
3711+ Encoder equivalent of _prepare_tp_inputs + _preprocess_inputs.
3712+ Consumes raw inputs dict, copies to pre-allocated CUDA buffers,
3713+ sets up attention metadata, and returns model-ready dict.
3714+
3715+ Args:
3716+ inputs: Dict with required keys 'input_ids' ([total_tokens]) and
3717+ 'seq_lens' ([batch_size]). Optional 'position_ids'
3718+ ([total_tokens]). Any additional keys (token_type_ids,
3719+ inputs_embeds, etc.) are passed through to the model's
3720+ forward() via **kwargs.
3721+ """
3722+ token_ids = inputs ['input_ids' ]
3723+ seq_lens = inputs ['seq_lens' ]
3724+ position_ids = inputs .get ('position_ids' )
3725+ num_tokens = token_ids .shape [0 ]
3726+ batch_size = seq_lens .shape [0 ]
3727+
3728+ assert num_tokens <= self .max_num_tokens , (
3729+ f"num_tokens ({ num_tokens } ) exceeds max_num_tokens "
3730+ f"({ self .max_num_tokens } ). Reduce batch size or sequence lengths." )
3731+
3732+ # 1. Copy to pre-allocated CUDA buffers
3733+ self .input_ids_cuda [:num_tokens ].copy_ (token_ids , non_blocking = True )
3734+ if position_ids is None :
3735+ # Auto-generate packed position IDs: [0..n1-1, 0..n2-1, ...]
3736+ position_ids = torch .cat (
3737+ [torch .arange (s , dtype = torch .int32 ) for s in seq_lens .tolist ()])
3738+ self .position_ids_cuda [:num_tokens ].copy_ (position_ids ,
3739+ non_blocking = True )
3740+
3741+ # 2. Set up attention metadata
3742+ attn_metadata = self ._set_up_attn_metadata (kv_cache_manager = None )
3743+ attn_metadata .seq_lens = seq_lens
3744+ attn_metadata .num_contexts = batch_size
3745+ attn_metadata .max_seq_len = self .max_seq_len
3746+ attn_metadata .request_ids = list (range (batch_size ))
3747+ attn_metadata .prepare ()
3748+
3749+ # 3. Build model-ready dict.
3750+ # **inputs goes FIRST so that the explicit buffer keys override the
3751+ # raw tensors. Extra keys pass through to the model's **kwargs
3752+ # are silently ignored if not in the model's forward() signature.
3753+ model_inputs = {
3754+ ** inputs ,
3755+ 'attn_metadata' : attn_metadata ,
3756+ 'input_ids' : self .input_ids_cuda [:num_tokens ],
3757+ 'position_ids' : self .position_ids_cuda [:num_tokens ].unsqueeze (0 ),
3758+ }
3759+
3760+ return model_inputs
3761+
3762+ @torch .inference_mode ()
3763+ @with_model_extra_attrs (lambda self : self .model .extra_attrs )
3764+ @nvtx_range ("encoder_forward" )
3765+ def encoder_forward (self , inputs : Dict [str , Any ]) -> Dict [str , Any ]:
3766+ """Direct tensor-level forward for encode-only path.
3767+
3768+ Bypasses ScheduledRequests/LlmRequest entirely. Takes a raw inputs
3769+ dict, prepares model-ready inputs via _prepare_encoder_inputs, and
3770+ calls _forward_step (which preserves torch.compile).
3771+
3772+ Args:
3773+ inputs: Dict with 'input_ids' and 'seq_lens' (required), plus
3774+ any model-specific kwargs (token_type_ids, inputs_embeds, etc.).
3775+
3776+ Returns:
3777+ Dict with 'logits' tensor and any other model outputs.
3778+ """
3779+ model_inputs = self ._prepare_encoder_inputs (inputs )
3780+ return self ._forward_step (model_inputs ,
3781+ gather_ids = None ,
3782+ gather_context_logits = False )
3783+
37083784 @torch .inference_mode ()
37093785 @with_model_extra_attrs (lambda self : self .model .extra_attrs )
37103786 def forward (self ,
@@ -3790,8 +3866,10 @@ def forward(self,
37903866 return self ._forward_step_mm_encoder_only (
37913867 inputs , scheduled_requests )
37923868 else :
3793- return self ._forward_step (inputs , gather_ids ,
3794- gather_context_logits )
3869+ return self ._forward_step (
3870+ inputs ,
3871+ gather_ids = gather_ids ,
3872+ gather_context_logits = gather_context_logits )
37953873 with self .cuda_graph_runner .pad_batch (
37963874 scheduled_requests , resource_manager ,
37973875 self .runtime_draft_len ) as padded_requests :
@@ -3846,8 +3924,10 @@ def forward(self,
38463924 if not can_run_graph :
38473925 # Fallback to eager execution if graph was not used
38483926 with MoeLoadBalancerIterContext (moe_load_balancer ):
3849- outputs = self ._forward_step (inputs , gather_ids ,
3850- gather_context_logits )
3927+ outputs = self ._forward_step (
3928+ inputs ,
3929+ gather_ids = gather_ids ,
3930+ gather_context_logits = gather_context_logits )
38513931 else :
38523932 if self .cuda_graph_runner .needs_capture (key ):
38533933
@@ -3919,7 +3999,8 @@ def model_forward(self, **kwargs):
39193999 @nvtx_range ("_forward_step" )
39204000 def _forward_step (self ,
39214001 inputs : Dict [str , Any ],
3922- gather_ids : Optional [torch .Tensor ],
4002+ * ,
4003+ gather_ids : Optional [torch .Tensor ] = None ,
39234004 gather_context_logits : bool = False ) -> Dict [str , Any ]:
39244005 inputs = self ._preprocess_inputs (inputs )
39254006 if inputs .get ('spec_metadata' , None ):
0 commit comments