
Commit 6fe4681

[fix] Add logits_to_keep and shift_labels support for Qwen3-VL and Qwen3-VL-MoE (#1181)
# Summary

This PR adds `logits_to_keep` and `shift_labels` support for both `Qwen3-VL` and `Qwen3-VL-MoE` in the Liger-patched forward path. The change aligns the patched implementation with the expected Hugging Face interface and enables selective logits materialization for long-context inference (see the sketch below).

# Testing Done

- `make test` - not fully green; observed existing failures in `GRPO`, `fused_neighborhood_attention`, and `gemma3` monkey patch tests
- `make test-convergence` - not fully green; observed failure in `test/convergence/bf16/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llama4-...]`
- `make checkstyle` - passed

Known limitation:

- The failed `make test` / `make test-convergence` cases above do not directly exercise the `Qwen3-VL` or `Qwen3-VL-MoE` `logits_to_keep` / `shift_labels` change in this PR
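For context, here is a minimal, self-contained sketch of the slicing semantics the patch relies on. The `keep_hidden` helper name and the toy shapes are illustrative, not part of the PR:

```python
import torch

def keep_hidden(hidden_states: torch.Tensor, logits_to_keep) -> torch.Tensor:
    # int: keep only the last `logits_to_keep` positions; 0 keeps everything,
    # because slice(-0, None) == slice(0, None).
    # Tensor: interpreted as explicit sequence indices to keep.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

hidden_states = torch.randn(1, 8, 16)  # (batch, seq_len, hidden_size)
print(keep_hidden(hidden_states, 0).shape)                     # torch.Size([1, 8, 16])
print(keep_hidden(hidden_states, 1).shape)                     # torch.Size([1, 1, 16])
print(keep_hidden(hidden_states, torch.tensor([0, 7])).shape)  # torch.Size([1, 2, 16])
```

Slicing the hidden states before the `lm_head` projection means the full `(batch, seq_len, vocab_size)` logits tensor is never materialized, which is the memory win for long-context inference.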
1 parent fcaae50 commit 6fe4681

2 files changed

Lines changed: 34 additions & 9 deletions


src/liger_kernel/transformers/model/qwen3_vl.py

Lines changed: 16 additions & 4 deletions
@@ -33,6 +33,7 @@ def lce_forward(
     mm_token_type_ids: Optional[torch.IntTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
 ) -> Union[Tuple, LigerQwen3VLCausalLMOutputWithPast]:
@@ -53,6 +54,9 @@ def lce_forward(
             The rope index difference between sequence length and multimodal rope.
         second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
             The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
+        logits_to_keep (`int` or `torch.Tensor`, *optional*):
+            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+            input tokens. If a `torch.Tensor`, it must contain the sequence indices to keep.
     Example:
     ```python
     >>> from PIL import Image
@@ -106,6 +110,8 @@ def lce_forward(
     )

     hidden_states = outputs[0]
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
@@ -121,7 +127,7 @@ def lce_forward(

     if skip_logits:
         result = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -130,11 +136,17 @@ def lce_forward(
         )
         loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
     else:
-        logits = self.lm_head(hidden_states)
+        logits = self.lm_head(kept_hidden_states)

         loss = None
-        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.text_config.vocab_size,
+                **kwargs,
+            )

     if not return_dict:
         output = (logits,) + outputs[1:]
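The `shift_labels` kwarg accepted above lets a caller pass pre-shifted targets instead of raw `labels`. A sketch of the equivalence, assuming Hugging Face's pad-then-shift convention (the tensor contents below are illustrative):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[10, 11, 12, -100]])

# Pad with the ignore index on the right, then drop position 0, so that
# position t is trained to predict token t + 1.
shift_labels = F.pad(labels, (0, 1), value=-100)[..., 1:]
print(shift_labels)  # tensor([[  11,   12, -100, -100]])
```

Trainers that pack sequences or shard them across context-parallel ranks shift labels themselves, which is why the patched forward now computes a loss whenever either `labels` or `shift_labels` is present.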

src/liger_kernel/transformers/model/qwen3_vl_moe.py

Lines changed: 18 additions & 5 deletions
@@ -34,11 +34,16 @@ def lce_forward(
     mm_token_type_ids: Optional[torch.IntTensor] = None,
     cache_position: Optional[torch.LongTensor] = None,
     second_per_grid_ts: Optional[torch.Tensor] = None,
+    logits_to_keep: Union[int, torch.Tensor] = 0,
     skip_logits: Optional[bool] = None,
     **kwargs,
 ) -> Union[Tuple, LigerQwen3VLMoeCausalLMOutputWithPast]:
     """
     Qwen3-VL-MoE forward with fused linear cross entropy support mirroring Qwen3-VL behaviour.
+
+    logits_to_keep (`int` or `torch.Tensor`, *optional*):
+        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
+        input tokens. If a `torch.Tensor`, it must contain the sequence indices to keep.
     """

     output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -68,6 +73,8 @@ def lce_forward(
     )

     hidden_states = outputs[0]
+    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+    kept_hidden_states = hidden_states[:, slice_indices, :]

     shift_labels = kwargs.pop("shift_labels", None)
     loss = None
@@ -83,7 +90,7 @@ def lce_forward(

     if skip_logits:
         result = LigerForCausalLMLoss(
-            hidden_states=hidden_states,
+            hidden_states=kept_hidden_states,
             lm_head_weight=self.lm_head.weight,
             labels=labels,
             shift_labels=shift_labels,
@@ -92,10 +99,16 @@ def lce_forward(
         )
         loss, _, token_accuracy, predicted_tokens = unpack_cross_entropy_result(result)
     else:
-        logits = self.lm_head(hidden_states)
-
-        if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+        logits = self.lm_head(kept_hidden_states)
+
+        if labels is not None or shift_labels is not None:
+            loss = self.loss_function(
+                logits=logits,
+                labels=labels,
+                shift_labels=shift_labels,
+                vocab_size=self.config.text_config.vocab_size,
+                **kwargs,
+            )

     # Compute auxiliary load-balancing loss for MoE when requested
     aux_loss = None
