-
Notifications
You must be signed in to change notification settings - Fork 233
[OpenVINO] Support eagle3 draft model for Qwen3-VL model #1679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
openvino-agent
wants to merge
3
commits into
huggingface:main
Choose a base branch
from
openvino-agent:qwen3_vl_eagle3
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7901,6 +7901,148 @@ def forward( | |
| ) | ||
|
|
||
|
|
||
| if is_transformers_version(">=", "4.57"): | ||
| from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLTextRotaryEmbedding | ||
|
|
||
| class QwenVLEagle3Model(LlamaEagle3Model): | ||
| """ | ||
| Eagle-3 draft model with Qwen3-VL MRoPE for VLM speculative decoding. | ||
|
|
||
| Extends LlamaEagle3Model by replacing the standard rotary embedding with | ||
| Qwen3VLTextRotaryEmbedding, which supports interleaved multimodal RoPE | ||
| (MRoPE). This allows the draft model to handle position IDs compatible | ||
| with Qwen3-VL target models. | ||
|
|
||
| The forward signature is redefined to accept ``inputs_embeds`` as the | ||
| primary input (instead of ``input_ids``). This is critical because the | ||
| TorchScript tracer uses ``inspect.signature(model.forward)`` to determine | ||
| input parameter names and ordering. By removing ``input_ids`` from the | ||
| signature entirely, the traced/converted OpenVINO model will have | ||
| ``inputs_embeds`` as an explicit input (float32 3D tensor). | ||
|
|
||
| ``position_ids`` is kept as 3D ``[3, batch, seq]`` for MRoPE without | ||
| flattening. | ||
| """ | ||
|
|
||
| def __init__(self, config: LlamaConfig): | ||
| super().__init__(config) | ||
| # Replace standard rotary embedding with VLM-aware MRoPE embedding. | ||
| self.rotary_emb = Qwen3VLTextRotaryEmbedding(config=config) | ||
|
|
||
| def forward( | ||
| self, | ||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||
| hidden_states: Optional[torch.FloatTensor] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Cache] = None, | ||
| use_cache: Optional[bool] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, | ||
| **kwargs, | ||
| ) -> BaseModelOutputWithPast: | ||
| batch_size, seq_length, _ = hidden_states.shape | ||
| use_cache = use_cache if use_cache is not None else self.config.use_cache | ||
|
|
||
| if use_cache and past_key_values is None: | ||
| past_key_values = DynamicCache(config=self.config) | ||
|
|
||
| if cache_position is None: | ||
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 | ||
| cache_position = torch.arange( | ||
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device | ||
| ) | ||
|
|
||
| if position_ids is None: | ||
| position_ids = cache_position.unsqueeze(0) | ||
|
|
||
| causal_mask = create_causal_mask( | ||
| config=self.config, | ||
| input_embeds=inputs_embeds, | ||
| attention_mask=attention_mask, | ||
| cache_position=cache_position, | ||
| past_key_values=past_key_values, | ||
| position_ids=position_ids, | ||
| ) | ||
|
|
||
| if hidden_states is None: | ||
| hidden_states = torch.zeros( | ||
| [batch_size, seq_length, self.hidden_size], | ||
| dtype=inputs_embeds.dtype, | ||
| device=inputs_embeds.device, | ||
| ) | ||
|
|
||
| inputs_embeds = inputs_embeds.to(hidden_states.dtype) | ||
| if hidden_states.shape[-1] != inputs_embeds.shape[-1]: | ||
| hidden_states = self.fc(hidden_states) | ||
|
|
||
| position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids) | ||
|
|
||
| hidden_states = self.midlayer( | ||
| input_emb=inputs_embeds, | ||
| hidden_states=hidden_states, | ||
| attention_mask=causal_mask, | ||
| position_embeddings=position_embeddings, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| use_cache=True, | ||
| ) | ||
|
|
||
| hidden_states = self.norm(hidden_states) | ||
| return BaseModelOutputWithPast( | ||
| last_hidden_state=hidden_states, | ||
| past_key_values=past_key_values, | ||
| ) | ||
|
|
||
| class QwenVLEagle3ForCausalLM(LlamaEagle3ForCausalLM): | ||
| """ | ||
| Eagle-3 causal LM with Qwen3-VL MRoPE for VLM speculative decoding. | ||
|
|
||
| Uses QwenVLEagle3Model as the underlying model. The forward signature | ||
| is redefined to accept ``inputs_embeds`` instead of ``input_ids``, | ||
| ensuring the TorchScript tracer produces an OpenVINO model with | ||
| ``inputs_embeds`` as a named input parameter. | ||
| """ | ||
|
|
||
| def __init__(self, config): | ||
| super().__init__(config) | ||
| self.model = QwenVLEagle3Model(config) | ||
|
|
||
| def forward( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same |
||
| self, | ||
| inputs_embeds: Optional[torch.FloatTensor] = None, | ||
| hidden_states: Optional[torch.FloatTensor] = None, | ||
| attention_mask: Optional[torch.Tensor] = None, | ||
| position_ids: Optional[torch.LongTensor] = None, | ||
| past_key_values: Optional[Cache] = None, | ||
| use_cache: Optional[bool] = None, | ||
| cache_position: Optional[torch.LongTensor] = None, | ||
| logits_to_keep: Union[int, torch.Tensor] = 0, | ||
| **kwargs, | ||
| ) -> Eagle3Output: | ||
| outputs: BaseModelOutputWithPast = self.model( | ||
| inputs_embeds=inputs_embeds, | ||
| hidden_states=hidden_states, | ||
| attention_mask=attention_mask, | ||
| position_ids=position_ids, | ||
| past_key_values=past_key_values, | ||
| use_cache=use_cache, | ||
| cache_position=cache_position, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| hidden_states = outputs.last_hidden_state | ||
| slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep | ||
| logits = self.model.lm_head(hidden_states[:, slice_indices, :]) | ||
|
|
||
| d2t_out = self.identity(self.model.d2t) | ||
| return Eagle3Output( | ||
| logits=logits, | ||
| past_key_values=outputs.past_key_values, | ||
| hidden_states=outputs.hidden_states, | ||
| d2t=d2t_out, | ||
| ) | ||
|
|
||
|
|
||
| # Patched implementation of the gated delta rule in recurrent form. | ||
| # Adapted from: | ||
| # https://github.com/huggingface/transformers/blob/v4.57-release/src/transformers/models/qwen3_next/modeling_qwen3_next.py#L522 | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a comment linking to the original forward method and mentioning the differences please?