@@ -497,6 +497,379 @@ def hook(module, args, kwargs):
497497 }
498498
499499
class VLMVLLMBackend(BaseBackend):
    """VLM backend that runs inference through vLLM and captures hidden states.

    Forward hooks are installed on the language-model sub-module inside each
    vLLM worker (via ``LLM.apply_model``) to capture ``inputs_embeds``,
    ``position_ids`` and the output of every decoder layer.

    Supported model types:
        - qwen3_vl: Qwen3-VL series vision-language models
        - hunyuan_vl: HunYuan-VL series vision-language models
    """

    SUPPORT_MODEL_TYPE = ["qwen3_vl", "hunyuan_vl"]

    def load_model(self) -> None:
        """Create the vLLM engine and fetch its tokenizer.

        Engine options are read from ``self.kwargs`` (``tensor_parallel_size``,
        ``max_model_len``, ``gpu_memory_utilization``, ``enforce_eager``,
        ``max_num_seqs``, ``distributed_executor_backend``,
        ``limit_mm_per_prompt``), each with a default.

        Raises:
            ValueError: if ``self.target_model_type`` is not listed in
                ``SUPPORT_MODEL_TYPE`` (this includes ``None``).
        """
        from vllm import LLM

        # ``None`` is never a member of the list, so one membership test
        # covers both the "unset" and the "unsupported" case.
        if self.target_model_type not in self.SUPPORT_MODEL_TYPE:
            raise ValueError(
                f"{self.target_model_type} is not supported. "
                f"Supported types: {self.SUPPORT_MODEL_TYPE}"
            )

        # Extract vllm-related parameters from kwargs
        options = self.kwargs
        tp_size = options.get("tensor_parallel_size", 1)
        max_model_len = options.get("max_model_len", 8192)

        print_with_rank(f"Loading VLM model with vLLM backend: {self.model_path}")
        print_with_rank(f" tensor_parallel_size={tp_size}, max_model_len={max_model_len}")

        self.model = LLM(
            model=self.model_path,
            tensor_parallel_size=tp_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=options.get("gpu_memory_utilization", 0.9),
            # Eager mode by default: get_aux_and_target_hiddens relies on
            # forward hooks firing.
            enforce_eager=options.get("enforce_eager", True),
            max_num_seqs=options.get("max_num_seqs", 8),
            distributed_executor_backend=options.get("distributed_executor_backend", "mp"),
            trust_remote_code=True,
            limit_mm_per_prompt=options.get("limit_mm_per_prompt", {"image": 10, "video": 10}),
        )
        self.tokenizer = self.model.get_tokenizer()

    def _get_language_model_module_name(self) -> str:
        """Return the attribute name of the language-model sub-module.

        Raises:
            ValueError: if ``self.target_model_type`` has no known mapping.
        """
        if self.target_model_type == "qwen3_vl":
            return "language_model"
        if self.target_model_type == "hunyuan_vl":
            return "model"
        raise ValueError(f"Unsupported target model type: {self.target_model_type}")

    def _build_vllm_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        **kwargs,
    ) -> list:
        """Convert batched tensor inputs into a list of vLLM token prompts.

        Each batch row becomes one independent ``TokensPrompt``.

        Args:
            input_ids: token ids, indexed as ``[batch, seq_len]``.
            attention_mask: optional mask with the same layout; positions whose
                mask is zero are dropped. Tokens are selected by the mask
                itself, so both left- and right-padded batches are handled.
            **kwargs: may contain ``pixel_values`` and ``image_grid_thw``.

        Returns:
            One ``TokensPrompt`` per batch row.
        """
        from vllm import TokensPrompt

        pixel_values = kwargs.get("pixel_values")
        image_grid_thw = kwargs.get("image_grid_thw")

        prompts = []
        for i, row in enumerate(input_ids):
            if attention_mask is not None:
                # Boolean selection instead of ``row[:mask.sum()]``: a prefix
                # slice keeps pad tokens (and drops real ones) when the batch
                # is left-padded.
                keep = attention_mask[i].to(device=row.device, dtype=torch.bool)
                row = row[keep]

            prompt: dict = {"prompt_token_ids": row.tolist()}

            # Attach multimodal data
            if pixel_values is not None:
                # NOTE(review): indexing by ``i`` assumes exactly one image per
                # sample and that pixel_values / image_grid_thw are laid out
                # per sample - confirm against the processor output and the
                # multimodal input format vLLM expects.
                mm_data = {"image": pixel_values[i]}
                if image_grid_thw is not None:
                    mm_data["image_grid_thw"] = image_grid_thw[i : i + 1]
                prompt["multi_modal_data"] = mm_data

            prompts.append(TokensPrompt(**prompt))

        return prompts

    def get_hidden_states_and_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        """Not supported by this backend; always raises.

        Use ``get_aux_and_target_hiddens`` instead.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            "get_hidden_states_and_logits is not implemented for VLMVLLMBackend. "
            "Please use get_aux_and_target_hiddens instead."
        )

    def get_aux_and_target_hiddens(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Extract auxiliary and target hidden states using the vLLM backend.

        Forward hooks are registered inside the vLLM workers through
        ``apply_model``; they stash captured tensors (moved to CPU) in a
        temporary attribute on the worker's model. After a one-token
        ``generate`` call, a second ``apply_model`` call removes the hooks,
        deletes the attribute and returns the captured data.

        Note: vLLM should run with ``enforce_eager=True`` so that the forward
        hooks fire.

        Args:
            input_ids: token ids, indexed as ``[batch, seq_len]``.
            attention_mask: optional mask with the same layout.
            **kwargs: may contain:
                - pixel_values: image pixel values
                - image_grid_thw: image grid dimensions
                - aux_hidden_states_layer_ids: decoder-layer indices whose
                  outputs are concatenated; when omitted,
                  ``self._get_default_aux_layer_ids(num_layers)`` is used.

        Returns:
            dict with keys:
                - hidden_states: selected layer outputs concatenated along the
                  last dimension
                - target_hiddens: output captured from the last decoder layer
                - inputs_embeds: captured input embeddings, or ``None``
                - position_ids: captured position ids, or ``None``
            All tensors are moved to ``input_ids.device``.
            NOTE(review): the tensor layout is whatever vLLM passes through the
            hooked modules; it may be token-flattened rather than
            ``[batch, seq_len, hidden]`` - confirm before relying on shapes.

        Raises:
            RuntimeError: if no worker returned any hidden state, or a
                requested/last layer produced no output.
            IndexError: if a requested layer index is out of range.
        """
        lm_module_name = self._get_language_model_module_name()
        aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids")
        # Temporary attribute name used to store hook data inside the worker
        _CACHE_ATTR = "_vlm_vllm_hook_cache"

        # The two functions below are shipped to the workers by apply_model.
        # They deliberately close over plain locals only (never ``self``), so
        # the engine object is not dragged along with them.

        def setup_hooks_fn(model):
            """Install capture hooks on ``model`` inside a vLLM worker.

            Tensors are gathered across the default process group and stored
            only on rank 0; other ranks keep their cache entries as ``None``.
            """
            import torch.distributed as worker_dist

            _LAYER_ATTRS = ("layers", "decoder_layers", "h", "blocks")

            # Retrieve the language model sub-module
            lm = getattr(model, lm_module_name, None)
            if lm is None:
                raise AttributeError(
                    f"Model does not have attribute '{lm_module_name}'. "
                    f"Available attributes: {list(model.__dict__.keys())}"
                )

            handles = []
            cache = {
                "all_hidden_states": [],
                "inputs_embeds": None,
                "position_ids": None,
            }
            setattr(model, _CACHE_ATTR, cache)

            # NOTE(review): this is the rank in the default process group,
            # which equals the TP rank only when TP is the sole parallel
            # dimension - confirm.
            tp_rank = worker_dist.get_rank() if worker_dist.is_initialized() else 0

            def _all_gather_hidden(hidden: torch.Tensor) -> torch.Tensor:
                """All-gather ``hidden`` across ranks and concatenate on the last dim.

                NOTE(review): this assumes each rank holds a slice of the
                hidden dimension. If the hooked modules already emit the full
                (replicated) hidden size, the result would be duplicated
                ``world_size`` times - confirm against vLLM's TP layout.
                """
                if not worker_dist.is_initialized() or worker_dist.get_world_size() <= 1:
                    return hidden
                world_size = worker_dist.get_world_size()
                gathered = [torch.empty_like(hidden) for _ in range(world_size)]
                worker_dist.all_gather(gathered, hidden)
                return torch.cat(gathered, dim=-1)

            def _find_layers(root):
                """Return the first decoder-layer container found on ``root``."""
                for name in _LAYER_ATTRS:
                    found = getattr(root, name, None)
                    if found is not None:
                        return found
                return None

            # Hook 1: capture inputs_embeds and position_ids
            def pre_hook(module, args, hook_kwargs):
                embeds = hook_kwargs.get("inputs_embeds")
                if embeds is not None:
                    # The gather is a collective: every rank must take part,
                    # so it runs before the rank-0 check.
                    embeds = _all_gather_hidden(embeds.clone().detach())
                    if tp_rank == 0:
                        cache["inputs_embeds"] = embeds.cpu()
                positions = hook_kwargs.get("position_ids")
                if positions is not None and tp_rank == 0:
                    cache["position_ids"] = positions.clone().detach().cpu()
                return args, hook_kwargs

            handles.append(lm.register_forward_pre_hook(pre_hook, with_kwargs=True))

            # Hook 2: a post-hook on each decoder layer collects its output.
            # Look on the language model first, then on its ``model`` child.
            layers = _find_layers(lm)
            if layers is None and hasattr(lm, "model"):
                layers = _find_layers(lm.model)

            def make_layer_hook(idx):
                # Factory so each hook binds its own layer index.
                def layer_hook(module, args, output):
                    # NOTE(review): for layers returning a tuple only element 0
                    # is captured; if the layer returns (hidden, residual) this
                    # is not the full residual stream - confirm.
                    hidden = output[0] if isinstance(output, tuple) else output
                    hidden = _all_gather_hidden(hidden.clone().detach())
                    if tp_rank == 0:
                        # NOTE(review): each forward pass overwrites the slot;
                        # if vLLM splits the batch over several scheduler
                        # steps only the last one is retained - confirm.
                        cache["all_hidden_states"][idx] = hidden.cpu()
                    return output

                return layer_hook

            if layers is not None:
                cache["all_hidden_states"] = [None] * len(layers)
                for layer_idx, layer in enumerate(layers):
                    handles.append(layer.register_forward_hook(make_layer_hook(layer_idx)))

            # Store handles in cache for later cleanup
            cache["_handles"] = handles
            return True

        def collect_and_cleanup_fn(model):
            """Remove the hooks, drop the temporary attribute, return captured data."""
            cache = getattr(model, _CACHE_ATTR, None)
            if cache is None:
                return None
            for handle in cache.get("_handles", []):
                handle.remove()
            result = {
                "all_hidden_states": cache["all_hidden_states"],
                "inputs_embeds": cache["inputs_embeds"],
                "position_ids": cache["position_ids"],
            }
            delattr(model, _CACHE_ATTR)
            return result

        from vllm import SamplingParams

        # Build everything that can fail on bad input before touching the
        # workers, so a failure here leaves no hooks behind.
        prompts = self._build_vllm_inputs(input_ids, attention_mask, **kwargs)
        sampling_params = SamplingParams(
            temperature=0,
            max_tokens=1,
            logprobs=0,
        )

        try:
            # Setup is inside the try so that a failure on one worker still
            # triggers cleanup on the workers that did install hooks.
            self.model.apply_model(setup_hooks_fn)
            # Run inference (vLLM internally triggers the hooks)
            self.model.generate(prompts, sampling_params=sampling_params)
        finally:
            # Read data from the workers and clean up hooks
            worker_results = self.model.apply_model(collect_and_cleanup_fn)

        # apply_model returns one result per worker; pick the first one that
        # actually holds at least one captured layer output.
        collected = next(
            (
                result
                for result in (worker_results or [])
                if result is not None
                and result.get("all_hidden_states")
                and any(h is not None for h in result["all_hidden_states"])
            ),
            None,
        )
        if collected is None:
            raise RuntimeError(
                "Failed to collect hidden states from vLLM model. "
                "apply_model returned no valid results."
            )

        all_hs = collected["all_hidden_states"]
        inputs_embeds = collected["inputs_embeds"]
        position_ids = collected["position_ids"]
        num_layers = len(all_hs)

        # Determine auxiliary layer indices
        if aux_layer_ids is None:
            aux_layer_ids = self._get_default_aux_layer_ids(num_layers)

        # Validate each requested layer so a missing capture surfaces as a
        # clear error instead of an opaque failure inside torch.cat.
        selected_hiddens = []
        for layer_id in aux_layer_ids:
            if not -num_layers <= layer_id < num_layers:
                raise IndexError(
                    f"aux_hidden_states_layer_ids contains {layer_id}, "
                    f"but only {num_layers} layers were hooked."
                )
            layer_hidden = all_hs[layer_id]
            if layer_hidden is None:
                raise RuntimeError(
                    f"No hidden states were captured for layer {layer_id}."
                )
            selected_hiddens.append(layer_hidden)
        aux_hidden_states = torch.cat(selected_hiddens, dim=-1)

        # Final-layer hidden states
        target_hidden_states = all_hs[-1]
        if target_hidden_states is None:
            raise RuntimeError("No hidden states were captured for the final layer.")

        # hunyuan_vl: for 3-D position_ids keep index 0 along dimension 1
        if (
            self.target_model_type == "hunyuan_vl"
            and position_ids is not None
            and position_ids.dim() == 3
        ):
            position_ids = position_ids[:, 0, :]

        # Move results to the same device as input_ids
        device = input_ids.device
        return {
            "hidden_states": aux_hidden_states.to(device),
            "target_hiddens": target_hidden_states.to(device),
            "inputs_embeds": None if inputs_embeds is None else inputs_embeds.to(device),
            "position_ids": None if position_ids is None else position_ids.to(device),
        }
871+
872+
500873class AudioTransformersBackend (BaseBackend ):
501874 """Audio HuggingFace Transformers backend"""
502875
@@ -768,6 +1141,7 @@ class TargetModelWrapper:
7681141 ("hf" , "VLM" ): VLMTransformersBackend ,
7691142 ("hf" , "TTS" ): TTSTransformersBackend ,
7701143 ("hf" , "Audio" ): AudioTransformersBackend ,
1144+ ("vllm" , "VLM" ): VLMVLLMBackend ,
7711145 }
7721146
7731147 def __init__ (
@@ -916,6 +1290,9 @@ def create_target_model(
9161290 # Add backend-specific configuration
9171291 if backend == "hf" :
9181292 kwargs ["torch_dtype" ] = torch_dtype
1293+ elif backend == "vllm" :
1294+ # vllm backend does not use the torch_dtype parameter; other extra_kwargs are kept
1295+ pass
9191296 else :
9201297 raise ValueError (
9211298 f"Unsupported backend: '{ backend } '. "
0 commit comments