
[Qwen3Next] preserve linear-attn-mask optimization under torch.compile/export #46148

Open
yuvrajsharma9981 wants to merge 2 commits into huggingface:main from yuvrajsharma9981:yuvi/qwen3_5-export-compat

@yuvrajsharma9981 yuvrajsharma9981 commented May 21, 2026

Hi,

`torch.export.export` fails on Qwen3Next-family models with `GuardOnDataDependentSymNode: Could not guard on data-dependent expression Eq(u0, 1)`. The crash traces to `Qwen3NextModel._update_linear_attn_mask`:

```python
if (past_key_values is not None and past_key_values.has_previous_state()) or (
    attention_mask is not None and torch.all(attention_mask == 1)
):
    linear_attn_mask = None
```

`torch.all(attention_mask == 1)` produces a 0-dim bool tensor, and Python's `if` implicitly converts it to a Python bool (effectively an `.item()` call). Under tracing, that value is an unbacked symbolic int the exporter can't resolve. Net effect: any user wanting an AOT package (`torch._inductor.aoti_compile_and_package` → `.pt2`) for any model in this family is blocked at the export step.
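To make the failure mode concrete, here is a toy model in plain Python of a branch on a value that is only known at run time. It does not use PyTorch and is not how the exporter is implemented; `Unbacked` and `branch` are hypothetical names invented for this sketch. It only shows that an `if` has to ask its operand for a concrete truth value, which a tracing placeholder cannot supply.

```python
class Unbacked:
    """Hypothetical placeholder for a value that is unknown while tracing."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __bool__(self) -> bool:
        # A Python `if` calls __bool__; a placeholder has no concrete answer.
        raise RuntimeError(f"could not guard on data-dependent value {self.name}")


def branch(all_ones) -> str:
    # Mirrors the shape of `if torch.all(attention_mask == 1): ...`
    if all_ones:
        return "fast path"
    return "masked path"


# With a concrete value the branch resolves normally.
print(branch(True))

# With a placeholder the branch cannot be resolved and raises.
try:
    branch(Unbacked("u0"))
except RuntimeError as err:
    print(err)
```

The real exporter raises `GuardOnDataDependentSymNode` rather than `RuntimeError`, but the cause has the same shape: the branch condition depends on tensor data that does not exist at trace time.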

I tripped on this trying to AOT compile Qwen3.5 for fast serving — the eager forward works, the export step crashes.

Scope

Fix lands at the modular source-of-truth, so the same patch propagates to all four models that inherit `Qwen3NextModel._update_linear_attn_mask`:

  • Qwen3Next (direct)
  • Qwen3.5 (`Qwen3_5TextModel(Qwen3NextModel)`)
  • Qwen3.5-MoE (`Qwen3_5MoeTextModel` via the same lineage)
  • OLMo Hybrid (`OlmoHybridModel(Qwen3NextModel)`)

Fix

Smallest behavior-preserving thing I could come up with: keep the eager-mode fast-path identical, and skip the data-dependent branch only when `torch.compiler.is_compiling()` is true. The downstream linear-attention layer treats an all-1s mask as a cheap no-op, so the exported graph runs correctly for the no-padding case that the eager path was short-circuiting.

```python
def _update_linear_attn_mask(self, attention_mask, past_key_values):
    linear_attn_mask = attention_mask
    if past_key_values is not None and past_key_values.has_previous_state():
        return None
    if torch.compiler.is_compiling():
        return linear_attn_mask
    if attention_mask is not None and torch.all(attention_mask == 1):
        linear_attn_mask = None
    return linear_attn_mask
```

Two notes on the ordering:

  1. The cached-forward check stays first so users exporting a decode-step graph still get the cached-skip optimization baked into the resulting graph — that branch is already export-compatible (Python object state, not a tensor `.item()`).
  2. `torch.compiler.is_compiling()` is the public PyTorch idiom for "behave differently under trace"; runtime behavior for everyone not exporting is byte-identical to before.
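The ordering described above can be checked without PyTorch using a small stand-in. This is a sketch under assumptions, not the library code: `FakeCache` is a hypothetical substitute for the real cache object, masks are nested lists instead of tensors, and the `compiling` flag replaces the call to `torch.compiler.is_compiling()`.

```python
from typing import Optional, Sequence

Mask = Sequence[Sequence[int]]


class FakeCache:
    """Hypothetical stand-in for the cache; only models has_previous_state()."""

    def __init__(self, has_state: bool) -> None:
        self._has_state = has_state

    def has_previous_state(self) -> bool:
        return self._has_state


def update_linear_attn_mask(
    mask: Optional[Mask], cache: Optional[FakeCache], compiling: bool
) -> Optional[Mask]:
    # 1. Cached decode step: plain Python object state, safe under tracing.
    if cache is not None and cache.has_previous_state():
        return None
    # 2. Tracing: never inspect mask values, pass the mask through unchanged.
    if compiling:
        return mask
    # 3. Eager only: drop the mask when there is no padding at all.
    if mask is not None and all(v == 1 for row in mask for v in row):
        return None
    return mask


ones = [[1, 1], [1, 1]]
padded = [[0, 1], [1, 1]]

assert update_linear_attn_mask(ones, None, compiling=False) is None    # eager fast path
assert update_linear_attn_mask(ones, None, compiling=True) is ones     # traced: mask kept
assert update_linear_attn_mask(padded, None, compiling=False) is padded
assert update_linear_attn_mask(ones, FakeCache(True), compiling=True) is None
```

The last assertion is the reason for putting the cache check first: a decode-step trace still drops the mask, because that decision never reads tensor data.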

Reproducer

Fails on v5.9.0 + torch 2.11:

```python
import torch
from transformers import AutoModelForCausalLM

m = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3.5-4B", torch_dtype=torch.bfloat16)
m.eval()

class W(torch.nn.Module):
    # Thin wrapper so the exported graph takes positional tensors and returns logits.
    def __init__(s, m):
        super().__init__()
        s.m = m

    def forward(s, ids, mask):
        return s.m(input_ids=ids, attention_mask=mask).logits

ids = torch.ones(2, 128, dtype=torch.long)
mask = torch.ones(2, 128, dtype=torch.long)

torch.export.export(W(m), (ids, mask), dynamic_shapes={
    "ids": {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
    "mask": {0: torch.export.Dim.AUTO, 1: torch.export.Dim.AUTO},
})
```

After the change, verified locally with a forked-source install:

  • eager runtime with all-1s mask still returns `None` (existing optimization preserved)
  • `torch.export.export(...)` succeeds and traces a clean graph

Commits

  • First commit edited the generated `modeling_qwen3_5.py` directly — CI correctly flagged this via `check_repository_consistency`.
  • Second commit moves the fix to `modular_qwen3_next.py` (the source-of-truth) and regenerates the affected `modeling_*.py` files via `make fix-repo`. Both checks pass locally now.

Happy to add tests under `tests/models/qwen3_next/` (and the inheriting models) if that's the preferred shape — held off pending guidance on existing export-compat coverage conventions.

Thanks!

…export

`Qwen3_5TextModel._update_linear_attn_mask` short-circuits to None when
`torch.all(attention_mask == 1)` is true. That check internally calls
`.item()` on a 0-dim bool tensor to drive the Python `if`, which
`torch.export` can't trace — it produces an unbacked symbolic int and
fails to guard on `Eq(u0, 1)`. Any Qwen3.5-based model is therefore
un-exportable via `torch.export.export`, which also blocks generating
AOT `.pt2` packages with `torch._inductor.aoti_compile_and_package`.

Hit this trying to AOT-compile Qwen3.5 for fast inference serving;
the eager forward works, the export step crashes with
`GuardOnDataDependentSymNode: Could not guard on data-dependent
expression Eq(u0, 1)`.

Smallest behavior-preserving fix I could come up with: keep the
cached-forward shortcut (already export-compatible because it's a
Python object-state check) and keep the no-padding fast-path in eager
mode, but skip the data-dependent branch under
`torch.compiler.is_compiling()` and pass the mask through unchanged.
The downstream linear-attention layer handles an all-1s mask as a
cheap no-op, so the exported graph behaves the same as the eager
fast-path for no-padding inputs.

Order: the cached check stays first so users exporting a decode-step
graph still get the cached-skip optimization baked into the resulting
graph.
modular_qwen3_next.py is the source-of-truth for _update_linear_attn_mask;
Qwen3_5TextModel, Qwen3_5MoeModel, and OlmoHybridModel all inherit from
Qwen3NextModel via modular conversion, so the same export-blocker exists
in all four models. Fixing it once in the modular source + regenerating
the modeling files lands the fix uniformly across the family.

Also fixes the check_repository_consistency CI job that flagged the
direct edit to the generated modeling_qwen3_5.py.

@yuvrajsharma9981 changed the title from "[Qwen3.5] preserve linear-attn-mask optimization under torch.compile/export" to "[Qwen3Next] preserve linear-attn-mask optimization under torch.compile/export" on May 21, 2026

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: olmo_hybrid, qwen3_5, qwen3_5_moe, qwen3_next

@Rocketknight1
Member

cc @ArthurZucker for text models!
