[Qwen3Next] preserve linear-attn-mask optimization under torch.compile/export#46148
Open
yuvrajsharma9981 wants to merge 2 commits into
…export

`Qwen3_5TextModel._update_linear_attn_mask` short-circuits to `None` when `torch.all(attention_mask == 1)` is true. That check implicitly calls `.item()` on a 0-dim bool tensor to drive the Python `if`, which `torch.export` can't trace: it produces an unbacked symbolic int and fails to guard on `Eq(u0, 1)`. Any Qwen3.5-based model is therefore un-exportable via `torch.export.export`, which also blocks generating AOT `.pt2` packages with `torch._inductor.aoti_compile_and_package`.

Hit this trying to AOT-compile Qwen3.5 for fast inference serving; the eager forward works, the export step crashes with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1)`.

Smallest behavior-preserving fix I could come up with: keep the cached-forward shortcut (already export-compatible because it's a Python object-state check) and keep the no-padding fast-path in eager mode, but skip the data-dependent branch under `torch.compiler.is_compiling()` and pass the mask through unchanged. The downstream linear-attention layer handles an all-1s mask as a cheap no-op, so the exported graph produces the same outputs as the eager fast-path for no-padding inputs.

Order: the cached check stays first so users exporting a decode-step graph still get the cached-skip optimization baked into the resulting graph.
`modular_qwen3_next.py` is the source of truth for `_update_linear_attn_mask`; `Qwen3_5TextModel`, `Qwen3_5MoeModel`, and `OlmoHybridModel` all inherit from `Qwen3NextModel` via modular conversion, so the same export blocker exists in all four models. Fixing it once in the modular source and regenerating the modeling files lands the fix uniformly across the family. This also fixes the `check_repository_consistency` CI job, which flagged the direct edit to the generated `modeling_qwen3_5.py`.
Contributor
[For maintainers] Suggested jobs to run (before merge)
run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next
Member
cc @ArthurZucker for text models!
Hi,
`torch.export.export` fails on Qwen3Next-family models with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1)`. The crash traces to `Qwen3NextModel._update_linear_attn_mask`:
```python
if (past_key_values is not None and past_key_values.has_previous_state()) or (
    attention_mask is not None and torch.all(attention_mask == 1)
):
    linear_attn_mask = None
```
`torch.all(attention_mask == 1)` produces a 0-dim bool tensor, and Python's `if` calls `bool()` on it, which is an implicit `.item()`. During export that value is an unbacked symbolic int the exporter can't resolve. Net effect: any user wanting an AOT package (`torch._inductor.aoti_compile_and_package` → `.pt2`) for any model in this family is blocked at the export step.
I tripped on this trying to AOT compile Qwen3.5 for fast serving: the eager forward works, the export step crashes.
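For anyone who wants to see the failure mode without downloading a checkpoint, here is a minimal sketch of the same pattern. `DataDependentBranch` and `try_export` are hypothetical names made up for this illustration, not transformers code, and the exact exception class raised can differ between torch versions and export modes, so the sketch only reports its name.

```python
import torch


class DataDependentBranch(torch.nn.Module):
    """Toy module (hypothetical) that mirrors the problematic pattern."""

    def forward(self, hidden, mask):
        # `if <0-dim bool tensor>` calls bool() on the tensor, which needs
        # the concrete value of the mask at trace time.
        if torch.all(mask == 1):
            return hidden
        return hidden * mask[:, :, None]


def try_export(module, args):
    """Return None if export succeeds, otherwise the exception type name."""
    try:
        torch.export.export(module, args)
        return None
    except Exception as exc:  # exact class varies across torch versions
        return type(exc).__name__


hidden = torch.randn(2, 8, 4)
mask = torch.ones(2, 8, dtype=torch.long)

# Eager mode is fine: the mask values are real, so the branch can be decided.
eager_out = DataDependentBranch()(hidden, mask)

# Export traces with fake tensors, so the branch condition has no value.
error_name = try_export(DataDependentBranch(), (hidden, mask))
print("export error:", error_name)
```

The eager call returns `hidden` untouched, while the export call is expected to fail on the `if`, which is the same shape of failure as the one in `_update_linear_attn_mask`.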
Scope
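Fix lands at the modular source-of-truth, so the same patch propagates to all four models that use `Qwen3NextModel._update_linear_attn_mask`: `Qwen3NextModel` itself plus `Qwen3_5TextModel`, `Qwen3_5MoeModel`, and `OlmoHybridModel`, which inherit it via modular conversion.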
Fix
Smallest behavior-preserving thing I could come up with: keep the eager-mode fast-path identical, and skip the data-dependent branch only when `torch.compiler.is_compiling()` is true. The downstream linear-attention layer treats an all-1s mask as a cheap no-op, so the exported graph runs correctly for the no-padding case that the eager path was short-circuiting.
```python
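def _update_linear_attn_mask(self, attention_mask, past_key_values):
    linear_attn_mask = attention_mask
    # Cached forward: a Python object-state check, so it is safe under export.
    if past_key_values is not None and past_key_values.has_previous_state():
        return None
    # While compiling/exporting, skip the data-dependent check and keep the mask.
    if torch.compiler.is_compiling():
        return linear_attn_mask
    # Eager only: drop the mask when there is no padding.
    if attention_mask is not None and torch.all(attention_mask == 1):
        linear_attn_mask = None
    return linear_attn_mask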
```
Two notes on the ordering:
Reproducer
Fails on v5.9.0 + torch 2.11:
```python
import torch
from transformers import AutoModelForCausalLM

m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", torch_dtype=torch.bfloat16)
m.eval()


class W(torch.nn.Module):
    # Thin wrapper so export sees positional tensor inputs and a tensor output.
    def __init__(s, m):
        super().__init__()
        s.m = m

    def forward(s, ids, mask):
        return s.m(input_ids=ids, attention_mask=mask).logits


ids = torch.ones(2, 128, dtype=torch.long)
mask = torch.ones(2, 128, dtype=torch.long)
torch.export.export(W(m), (ids, mask), dynamic_shapes={
    "ids": {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
    "mask": {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
})
```
After the change, verified locally with a forked-source install:
Commits
Happy to add tests under `tests/models/qwen3_next/` (and the inheriting models) if that's the preferred shape; I held off pending guidance on existing export-compat coverage conventions.
Thanks!