Skip to content

Commit 95824fc

Browse files
committed
feat: support extract hidden states with vllm backend
1 parent 028f7ab commit 95824fc

1 file changed

Lines changed: 377 additions & 0 deletions

File tree

angelslim/compressor/speculative/train/models/target/target_model_wrapper.py

Lines changed: 377 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,379 @@ def hook(module, args, kwargs):
497497
}
498498

499499

500+
class VLMVLLMBackend(BaseBackend):
    """VLM vLLM backend: run inference with vLLM and extract hidden states.

    Forward hooks are registered inside the vLLM worker(s) through
    ``LLM.apply_model``: a pre-hook on the language-model sub-module captures
    ``inputs_embeds`` / ``position_ids`` and a post-hook on every decoder layer
    captures that layer's hidden states.

    Supported model types:
    - qwen3_vl: Qwen3-VL series vision-language models
    - hunyuan_vl: HunYuan-VL series vision-language models
    """

    SUPPORT_MODEL_TYPE = ["qwen3_vl", "hunyuan_vl"]

    def load_model(self) -> None:
        """Create the vLLM ``LLM`` engine and fetch its tokenizer.

        vLLM options are read from ``self.kwargs`` (``tensor_parallel_size``,
        ``max_model_len``, ``gpu_memory_utilization``, ``enforce_eager``,
        ``max_num_seqs``, ``distributed_executor_backend``,
        ``limit_mm_per_prompt``); each falls back to the default shown below.

        Raises:
            ValueError: if ``self.target_model_type`` is not listed in
                ``SUPPORT_MODEL_TYPE``.
        """
        from vllm import LLM

        if self.target_model_type is None or self.target_model_type not in self.SUPPORT_MODEL_TYPE:
            raise ValueError(
                f"{self.target_model_type} is not supported. "
                f"Supported types: {self.SUPPORT_MODEL_TYPE}"
            )

        # Extract vllm-related parameters from kwargs
        tp_size = self.kwargs.get("tensor_parallel_size", 1)
        max_model_len = self.kwargs.get("max_model_len", 8192)
        gpu_memory_utilization = self.kwargs.get("gpu_memory_utilization", 0.9)
        # Eager mode by default: the hook-based extraction below documents that it
        # needs enforce_eager=True (no CUDA graph) for forward hooks to fire.
        enforce_eager = self.kwargs.get("enforce_eager", True)
        max_num_seqs = self.kwargs.get("max_num_seqs", 8)
        distributed_executor_backend = self.kwargs.get("distributed_executor_backend", "mp")
        limit_mm_per_prompt = self.kwargs.get("limit_mm_per_prompt", {"image": 10, "video": 10})

        print_with_rank(f"Loading VLM model with vLLM backend: {self.model_path}")
        print_with_rank(f"  tensor_parallel_size={tp_size}, max_model_len={max_model_len}")

        self.model = LLM(
            model=self.model_path,
            tensor_parallel_size=tp_size,
            max_model_len=max_model_len,
            gpu_memory_utilization=gpu_memory_utilization,
            enforce_eager=enforce_eager,
            max_num_seqs=max_num_seqs,
            distributed_executor_backend=distributed_executor_backend,
            trust_remote_code=True,
            limit_mm_per_prompt=limit_mm_per_prompt,
        )
        self.tokenizer = self.model.get_tokenizer()

    def _get_language_model_module_name(self) -> str:
        """Return the attribute name of the language-model sub-module.

        Raises:
            ValueError: if ``self.target_model_type`` has no known mapping.
        """
        if self.target_model_type == "qwen3_vl":
            return "language_model"
        elif self.target_model_type == "hunyuan_vl":
            return "model"
        else:
            raise ValueError(f"Unsupported target model type: {self.target_model_type}")

    def _build_vllm_inputs(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        **kwargs,
    ) -> list:
        """Convert batched tensor inputs to a list of vLLM ``TokensPrompt``.

        vLLM does not accept batched tensor inputs; each sample is converted to
        an independent prompt.

        Args:
            input_ids: token IDs, indexed as ``[batch, seq]``.
            attention_mask: same layout as ``input_ids``; non-zero entries mark
                the tokens to keep. ``None`` keeps every token.
            **kwargs: may contain ``pixel_values`` and ``image_grid_thw``.

        Returns:
            One ``TokensPrompt`` per sample, in batch order.
        """
        from vllm import TokensPrompt

        pixel_values = kwargs.get("pixel_values", None)
        image_grid_thw = kwargs.get("image_grid_thw", None)

        prompts = []
        for i in range(input_ids.shape[0]):
            if attention_mask is not None:
                # Select tokens by mask instead of slicing ``[:mask.sum()]``: the
                # slice is only correct for right-padded batches, whereas a
                # left-padded batch would keep the pads and drop the prompt tail.
                keep = attention_mask[i].to(device=input_ids.device, dtype=torch.bool)
                ids = input_ids[i][keep].tolist()
            else:
                ids = input_ids[i].tolist()

            prompt: dict = {"prompt_token_ids": ids}

            # Attach multimodal data.
            # NOTE(review): this assumes pixel_values / image_grid_thw carry exactly
            # one entry per sample along dim 0, and that vLLM accepts this payload
            # layout under "multi_modal_data" -- confirm against the data collator
            # and the vLLM multimodal input documentation.
            if pixel_values is not None:
                mm_data = {"image": pixel_values[i]}
                if image_grid_thw is not None:
                    mm_data["image_grid_thw"] = image_grid_thw[i : i + 1]
                prompt["multi_modal_data"] = mm_data

            prompts.append(TokensPrompt(**prompt))

        return prompts

    def get_hidden_states_and_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, ...]:
        """Not supported by this backend.

        Raises:
            NotImplementedError: always; use ``get_aux_and_target_hiddens``.
        """
        raise NotImplementedError(
            "get_hidden_states_and_logits is not implemented for VLMVLLMBackend. "
            "Please use get_aux_and_target_hiddens instead."
        )

    def get_aux_and_target_hiddens(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> dict:
        """Extract auxiliary and target hidden states using the vLLM backend.

        Registers forward hooks inside the vLLM workers via ``apply_model``,
        runs a one-token generation so the hooks fire, then reads the captured
        tensors back and removes the hooks via a second ``apply_model`` call.

        Note: this requires vLLM to run with enforce_eager=True (no CUDA graph)
        so that forward hooks fire correctly.

        NOTE(review): every hook invocation overwrites the previous capture, so
        if vLLM splits the prompts over several scheduler steps only the last
        step's tensors survive -- confirm batches always fit in a single step.

        Args:
            input_ids: input token IDs, indexed as ``[batch, seq]``.
            attention_mask: mask with the same layout as ``input_ids``.
            **kwargs: may contain:
                - pixel_values: image pixel values
                - image_grid_thw: image grid dimensions
                - aux_hidden_states_layer_ids: list of auxiliary layer indices;
                  defaults to ``self._get_default_aux_layer_ids(num_layers)``.

        Returns:
            dict with:
                - hidden_states: selected auxiliary-layer hidden states
                  concatenated along the last dim
                - target_hiddens: hidden states captured at the last decoder layer
                - inputs_embeds: captured input embeddings, or None if the
                  pre-hook did not see them
                - position_ids: captured position IDs, or None if the pre-hook
                  did not see them
            All tensors are moved to ``input_ids.device``.
            NOTE(review): tensor shapes are whatever the vLLM worker passes to
            its modules (possibly flattened over tokens rather than
            ``[batch, seq, hidden]``) -- confirm before relying on the layout.

        Raises:
            RuntimeError: if no worker returned hidden states, or a requested
                auxiliary layer / the final layer was not captured.
            IndexError: if an auxiliary layer index is out of range.
        """
        lm_module_name = self._get_language_model_module_name()
        aux_layer_ids = kwargs.get("aux_hidden_states_layer_ids", None)
        # Temporary attribute name used to store hook data inside the worker
        _CACHE_ATTR = "_vlm_vllm_hook_cache"

        def setup_hooks_fn(model):
            """Register capture hooks on the worker's model; data lives on a model attribute.

            Captured tensors are all-gathered over the default process group and
            concatenated on the last dim; only rank 0 stores the result.
            NOTE(review): this assumes each rank holds a hidden_size / TP slice.
            vLLM's row-parallel layers normally all-reduce their outputs, in which
            case activations are already replicated and the gather would duplicate
            them -- confirm with tensor_parallel_size > 1.
            """
            import torch.distributed as worker_dist

            # Drop hooks left behind by an earlier call that failed before cleanup.
            stale = getattr(model, _CACHE_ATTR, None)
            if stale is not None:
                for old_handle in stale.get("_handles", []):
                    old_handle.remove()

            # Track handles in the cache from the start so that a failure while
            # registering still lets collect_and_cleanup_fn remove what was added.
            handles = []
            cache = {
                "all_hidden_states": [],
                "inputs_embeds": None,
                "position_ids": None,
                "_handles": handles,
            }
            setattr(model, _CACHE_ATTR, cache)

            # NOTE(review): get_rank() is the rank in the default (world) group; it
            # matches the TP rank only if TP is the sole parallel dimension.
            tp_rank = 0
            if worker_dist.is_initialized():
                tp_rank = worker_dist.get_rank()

            def _all_gather_hidden(hidden: torch.Tensor) -> torch.Tensor:
                """All-gather ``hidden`` over the default group, concatenating on dim -1."""
                if not worker_dist.is_initialized() or worker_dist.get_world_size() <= 1:
                    return hidden
                world_size = worker_dist.get_world_size()
                gathered = [torch.empty_like(hidden) for _ in range(world_size)]
                worker_dist.all_gather(gathered, hidden)
                return torch.cat(gathered, dim=-1)

            def _layer_output_to_hidden(output) -> torch.Tensor:
                """Return a detached copy of the hidden states a decoder layer produced.

                vLLM decoder layers return ``(hidden_states, residual)`` and defer
                the residual add to the next layer's fused norm, so ``output[0]``
                alone is not the layer's hidden state. When the second element is
                a tensor of identical shape it is treated as that residual and
                added; any other output shape keeps the plain ``output[0]``.
                """
                if not isinstance(output, tuple):
                    return output.clone().detach()
                hidden = output[0]
                residual = output[1] if len(output) > 1 else None
                if isinstance(residual, torch.Tensor) and residual.shape == hidden.shape:
                    return (hidden + residual).detach()
                return hidden.clone().detach()

            # Retrieve the language model sub-module
            lm = getattr(model, lm_module_name, None)
            if lm is None:
                raise AttributeError(
                    f"Model does not have attribute '{lm_module_name}'. "
                    f"Available attributes: {list(model.__dict__.keys())}"
                )

            # Hook 1: capture inputs_embeds and position_ids.
            # NOTE(review): only keyword arguments named exactly "inputs_embeds" /
            # "position_ids" are captured; if the worker calls this module
            # positionally or under other names (e.g. "positions"), these stay None.
            def pre_hook(module, args, hook_kwargs):
                if "inputs_embeds" in hook_kwargs and hook_kwargs["inputs_embeds"] is not None:
                    embeds = hook_kwargs["inputs_embeds"].clone().detach()
                    # Every rank must join the collective, even if it stores nothing.
                    embeds = _all_gather_hidden(embeds)
                    if tp_rank == 0:
                        cache["inputs_embeds"] = embeds.cpu()
                if "position_ids" in hook_kwargs and hook_kwargs["position_ids"] is not None:
                    if tp_rank == 0:
                        cache["position_ids"] = hook_kwargs["position_ids"].clone().detach().cpu()
                return args, hook_kwargs

            handles.append(lm.register_forward_pre_hook(pre_hook, with_kwargs=True))

            # Hook 2: a post-hook on each decoder layer collects hidden states.
            # Look for the layer list on lm first, then on lm.model.
            layer_attr_names = ["layers", "decoder_layers", "h", "blocks"]
            layers = None
            for attr in layer_attr_names:
                layers = getattr(lm, attr, None)
                if layers is not None:
                    break
            if layers is None and hasattr(lm, "model"):
                for attr in layer_attr_names:
                    layers = getattr(lm.model, attr, None)
                    if layers is not None:
                        break

            if layers is not None:
                cache["all_hidden_states"] = [None] * len(layers)

                def make_layer_hook(idx):
                    # Factory binds idx per layer (avoids late-binding closures).
                    def layer_hook(module, args, output):
                        hidden = _all_gather_hidden(_layer_output_to_hidden(output))
                        # Only rank 0 stores the data
                        if tp_rank == 0:
                            cache["all_hidden_states"][idx] = hidden.cpu()
                        return output

                    return layer_hook

                for layer_idx, layer in enumerate(layers):
                    handles.append(layer.register_forward_hook(make_layer_hook(layer_idx)))

            return True

        def collect_and_cleanup_fn(model):
            """Remove the hooks, return the captured data and drop the temporary attribute."""
            cache = getattr(model, _CACHE_ATTR, None)
            if cache is None:
                return None
            for h in cache.get("_handles", []):
                h.remove()
            result = {
                "all_hidden_states": cache["all_hidden_states"],
                "inputs_embeds": cache["inputs_embeds"],
                "position_ids": cache["position_ids"],
            }
            delattr(model, _CACHE_ATTR)
            return result

        try:
            # Register hooks inside the worker. Done inside the try so that a
            # failure on one worker still triggers hook removal on the others.
            self.model.apply_model(setup_hooks_fn)

            prompts = self._build_vllm_inputs(input_ids, attention_mask, **kwargs)

            # Run inference (vLLM internally triggers the hooks); one generated
            # token is enough since only the prompt forward pass is needed.
            from vllm import SamplingParams

            sampling_params = SamplingParams(
                temperature=0,
                max_tokens=1,
                logprobs=0,
            )
            _ = self.model.generate(prompts, sampling_params=sampling_params)

        finally:
            # Read data from the worker and clean up hooks
            worker_results = self.model.apply_model(collect_and_cleanup_fn)

        # apply_model returns one result per worker; only the rank-0 worker stores
        # tensors, so pick the first result holding at least one captured layer.
        collected = None
        for result in worker_results:
            if (
                result is not None
                and result.get("all_hidden_states")
                and any(h is not None for h in result["all_hidden_states"])
            ):
                collected = result
                break

        if collected is None:
            raise RuntimeError(
                "Failed to collect hidden states from vLLM model. "
                "apply_model returned no valid results."
            )

        all_hs = collected["all_hidden_states"]
        inputs_embeds = collected["inputs_embeds"]
        position_ids = collected["position_ids"]

        if not all_hs or all(h is None for h in all_hs):
            raise RuntimeError(
                "Failed to collect hidden states from vLLM model. "
                "Please check that the model architecture is supported and "
                "enforce_eager=True is set."
            )

        # Determine auxiliary layer indices
        if aux_layer_ids is None:
            aux_layer_ids = self._get_default_aux_layer_ids(len(all_hs))

        # Fail with a readable message instead of a TypeError from torch.cat when
        # a requested layer's hook never fired.
        selected_hiddens = [all_hs[layer_id] for layer_id in aux_layer_ids]
        missing = [
            layer_id
            for layer_id, hidden in zip(aux_layer_ids, selected_hiddens)
            if hidden is None
        ]
        if missing:
            raise RuntimeError(
                f"Hidden states were not captured for auxiliary layers {missing}."
            )
        aux_hidden_states = torch.cat(selected_hiddens, dim=-1)

        # Final-layer hidden states
        target_hidden_states = all_hs[-1]
        if target_hidden_states is None:
            raise RuntimeError("Hidden states were not captured for the final decoder layer.")

        # For hunyuan_vl, a 3-D position_ids keeps only index 0 of its second dim
        if self.target_model_type == "hunyuan_vl" and position_ids is not None:
            if position_ids.dim() == 3:
                position_ids = position_ids[:, 0, :]

        # Move results to the same device as input_ids
        device = input_ids.device
        aux_hidden_states = aux_hidden_states.to(device)
        target_hidden_states = target_hidden_states.to(device)
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.to(device)
        if position_ids is not None:
            position_ids = position_ids.to(device)

        return {
            "hidden_states": aux_hidden_states,
            "target_hiddens": target_hidden_states,
            "inputs_embeds": inputs_embeds,
            "position_ids": position_ids,
        }
872+
500873
class AudioTransformersBackend(BaseBackend):
501874
"""Audio HuggingFace Transformers backend"""
502875

@@ -768,6 +1141,7 @@ class TargetModelWrapper:
7681141
("hf", "VLM"): VLMTransformersBackend,
7691142
("hf", "TTS"): TTSTransformersBackend,
7701143
("hf", "Audio"): AudioTransformersBackend,
1144+
("vllm", "VLM"): VLMVLLMBackend,
7711145
}
7721146

7731147
def __init__(
@@ -916,6 +1290,9 @@ def create_target_model(
9161290
# Add backend-specific configuration
9171291
if backend == "hf":
9181292
kwargs["torch_dtype"] = torch_dtype
1293+
elif backend == "vllm":
1294+
# vllm backend does not use the torch_dtype parameter; other extra_kwargs are kept
1295+
pass
9191296
else:
9201297
raise ValueError(
9211298
f"Unsupported backend: '{backend}'. "

0 commit comments

Comments
 (0)