Commit 9d413ac

Add Qwen3.5 export support for 0.8B/2B/4B (#17800)
## Summary

- add initial Qwen3.5 model integration under examples/models/qwen3_5
- add HF->Meta checkpoint conversion for Qwen3.5 naming/layout
- wire Qwen3.5 model types into llama export config and registry
- add 0.8B, 2B, and 4B configs plus xnnpack fp32 export config
- add unit tests for Qwen3.5 attention path and weight conversion

## Test Plan

- PYTHONPATH=src pytest -q examples/models/llama/tests/test_qwen3_5_attention.py examples/models/qwen3_5/tests/test_convert_weights.py
- local smoke export succeeded for Qwen3.5-0.8B: /tmp/qwen35_test/qwen3_5_0_8b_fp32_smoke.pte

cc @mergennachin @iseeyuan @lucylq @helunwencser @tarun292 @kimishpatel @jackzhxng
1 parent 0af0162 commit 9d413ac

20 files changed

Lines changed: 1048 additions & 22 deletions
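
The HF->Meta checkpoint conversion mentioned in the summary lives in `examples/models/qwen3_5/convert_weights.py`, which this page does not display. A hypothetical sketch of the renaming idea only — the entries below are illustrative of llama-style conversions, not the actual Qwen3.5 table, which would also have to cover the DeltaNet projections:

```python
from typing import Dict

import torch

# Illustrative HF -> Meta key remap; NOT the real Qwen3.5 mapping.
_HF_TO_META: Dict[str, str] = {
    "model.embed_tokens.weight": "tok_embeddings.weight",
    "model.norm.weight": "norm.weight",
    "lm_head.weight": "output.weight",
    "model.layers.0.self_attn.q_proj.weight": "layers.0.attention.wq.weight",
}


def convert_weights(hf_sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rename known keys; pass everything else through unchanged.
    return {_HF_TO_META.get(k, k): v for k, v in hf_sd.items()}
```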

examples/models/BUCK

Lines changed: 1 addition & 0 deletions
```diff
@@ -29,6 +29,7 @@ fbcode_target(_kind = python_library,
         "//executorch/examples/models/gemma3:gemma3",  # @manual
         "//executorch/examples/models/qwen2_5:qwen2_5",  # @manual
         "//executorch/examples/models/qwen3:qwen3",  # @manual
+        "//executorch/examples/models/qwen3_5:qwen3_5",  # @manual
         "//executorch/examples/models/phi_4_mini:phi_4_mini",  # @manual
         "//executorch/examples/models/smollm2:smollm2",  # @manual
         "//executorch/examples/models/smollm3:smollm3",  # @manual
```

examples/models/llama/__init__.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -4,8 +4,17 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

-from .model import Llama2Model
+from typing import TYPE_CHECKING

-__all__ = [
-    Llama2Model,
-]
+if TYPE_CHECKING:
+    from .model import Llama2Model
+
+__all__ = ["Llama2Model"]
+
+
+def __getattr__(name):
+    if name == "Llama2Model":
+        from .model import Llama2Model
+
+        return Llama2Model
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
```
examples/models/llama/attention.py

Lines changed: 250 additions & 11 deletions
```diff
@@ -7,7 +7,7 @@
 import torch.nn.functional as F
 from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
 from executorch.examples.models.llama.rope import Rope
```
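
`RMSNormGated` itself is defined in `norm.py`, which this diff does not show. A plausible sketch, assuming it follows the Mamba2-style gated RMSNorm that HF uses for this layer type (SiLU gate applied before normalization, matching the `self.norm(core_attn_out, z)` call further down):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNormGated(nn.Module):
    """Assumed formulation: SiLU-gated RMSNorm (Mamba2/Qwen3-Next style)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        x = x * F.silu(gate)  # gate first, then normalize
        var = x.pow(2).mean(-1, keepdim=True)
        return self.weight * (x * torch.rsqrt(var + self.eps))
```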
```diff
@@ -347,27 +347,35 @@ def __init__(
         self.attention_qkv_bias = args.attention_qkv_bias
         self.use_qk_norm = args.use_qk_norm
         self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.use_q_gate = args.use_q_gate
         self.enable_dynamic_shape = args.enable_dynamic_shape
+        q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

         if self.use_qk_norm:
             q_norm_dim = self.head_dim
             k_norm_dim = self.head_dim
-            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
-            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+            self.q_norm_fn = RMSNorm(
+                q_norm_dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
+            self.k_norm_fn = RMSNorm(
+                k_norm_dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )

         self.wq = (
             LoRALinear(
                 in_dim=args.dim,
-                out_dim=args.n_heads * args.head_dim,
+                out_dim=q_out_dim,
                 rank=args.r,
                 alpha=args.lora_alpha,
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
             if args.target_modules is not None and "q_proj" in args.target_modules
-            else nn.Linear(
-                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
-            )
+            else nn.Linear(self.dim, q_out_dim, bias=self.attention_qkv_bias)
         )
         self.wk = (
             LoRALinear(
```
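
With `use_q_gate` enabled, `wq` packs the query and a per-head output gate into a single projection, which is why `q_out_dim` doubles (e.g. 16 heads of head_dim 128 grow from 2048 to 4096 output features). A standalone shape sketch of the chunk-and-gate path used in the forward hunks below (illustrative sizes, not a specific Qwen3.5 config):

```python
import torch

bsz, seqlen, n_heads, head_dim = 1, 4, 16, 128
q_and_gate = torch.randn(bsz, seqlen, n_heads, head_dim * 2)  # fused wq output
q, gate = torch.chunk(q_and_gate, 2, dim=-1)  # each (1, 4, 16, 128)
gate = gate.reshape(bsz, seqlen, -1)          # (1, 4, 2048)

attn_out = torch.randn(bsz, seqlen, n_heads * head_dim)  # stand-in for SDPA output
gated = attn_out * torch.sigmoid(gate)        # elementwise output gate
```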
```diff
@@ -452,10 +460,17 @@ def forward(
         input_pos = kwargs.get("input_pos")
         bsz, seqlen, _ = x.shape

-        # QKV
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
-        # We need view_copy elimination
-        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        if self.use_q_gate:
+            q_and_gate = self.wq(x).view(
+                bsz, seqlen, self.n_local_heads, self.head_dim * 2
+            )
+            q, gate = torch.chunk(q_and_gate, 2, dim=-1)
+            gate = gate.reshape(bsz, seqlen, -1)
+        else:
+            q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            gate = None
+
+        k, v = self.wk(x), self.wv(x)
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
```
```diff
@@ -492,6 +507,8 @@ def forward(
                 input_pos[0].item(), seqlen
             )
             output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
+            if gate is not None:
+                output = output * torch.sigmoid(gate)
             return self.wo(output), None

         # grouped multiquery attention: expand out keys and values
```
```diff
@@ -505,12 +522,234 @@
         output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

         output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)

         output = self.wo(output)

         return output, None


+def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
+    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+    return x * inv_norm
+
+
+@register_attention("gated_deltanet")
+class AttentionGatedDeltaNet(Attention):
+    """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state."""
+
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        **_kwargs: Any,
+    ):
+        super().__init__()
+        del rope  # DeltaNet layers do not use RoPE.
+
+        self.hidden_size = args.dim
+        self.max_batch_size = args.max_batch_size
+        self.layer_id = layer_id
+
+        assert args.linear_num_key_heads is not None
+        assert args.linear_num_value_heads is not None
+        assert args.linear_key_head_dim is not None
+        assert args.linear_value_head_dim is not None
+
+        self.num_k_heads = args.linear_num_key_heads
+        self.num_v_heads = args.linear_num_value_heads
+        self.head_k_dim = args.linear_key_head_dim
+        self.head_v_dim = args.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        self.conv_kernel_size = args.linear_conv_kernel_dim
+
+        assert (
+            self.num_v_heads % self.num_k_heads == 0
+        ), "linear_num_value_heads must be divisible by linear_num_key_heads."
+        self.head_repeat = self.num_v_heads // self.num_k_heads
+
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False)
+        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            bias=False,
+            padding=0,
+        )
+
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+        A = torch.empty(self.num_v_heads).uniform_(0, 16)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps)
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
+        self.register_buffer(
+            "conv_state",
+            torch.zeros(
+                self.max_batch_size,
+                self.conv_dim,
+                self.conv_kernel_size,
+                dtype=torch.float32,
+                device="cpu",
+            ),
+        )
+        self.register_buffer(
+            "recurrent_state",
+            torch.zeros(
+                self.max_batch_size,
+                self.num_v_heads,
+                self.head_k_dim,
+                self.head_v_dim,
+                dtype=torch.float32,
+                device="cpu",
+            ),
+        )
+
+    def _maybe_reset_state(
+        self, input_pos: Optional[torch.Tensor], batch_size: int
+    ) -> None:
+        if input_pos is None:
+            self.conv_state[:batch_size].zero_()
+            self.recurrent_state[:batch_size].zero_()
+            return
+        reset = (input_pos[0] == 0).to(self.conv_state.dtype)
+        keep = 1.0 - reset
+        self.conv_state[:batch_size].mul_(keep)
+        self.recurrent_state[:batch_size].mul_(keep)
+
+    def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
+        # mixed_qkv: (batch, seq_len, conv_dim)
+        batch_size, seq_len, _ = mixed_qkv.shape
+        mixed_qkv = mixed_qkv.transpose(1, 2)
+        state_len = self.conv_state.shape[-1]
+        hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1)
+        new_conv_state = hidden_states_new[:, :, -state_len:]
+        with torch.no_grad():
+            self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype))
+        out = F.conv1d(
+            hidden_states_new,
+            self.conv1d.weight,
+            self.conv1d.bias,
+            padding=0,
+            groups=self.conv_dim,
+        )
+        out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
+        return out.transpose(1, 2).contiguous()
+
+    def _recurrent_gated_delta_rule(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+    ) -> torch.Tensor:
+        # query/key/value: (batch, seq_len, num_heads, head_dim)
+        # g/beta: (batch, seq_len, num_heads)
+        initial_dtype = query.dtype
+        query = _l2norm(query, dim=-1, eps=1e-6)
+        key = _l2norm(key, dim=-1, eps=1e-6)
+        query, key, value, beta, g = [
+            x.transpose(1, 2).contiguous().to(torch.float32)
+            for x in (query, key, value, beta, g)
+        ]
+
+        batch_size, num_heads, sequence_length, k_head_dim = key.shape
+        v_head_dim = value.shape[-1]
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
+        core_attn_out = torch.zeros(
+            batch_size,
+            num_heads,
+            sequence_length,
+            v_head_dim,
+            device=value.device,
+            dtype=value.dtype,
+        )
+        last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)
+
+        for i in range(sequence_length):
+            q_t = query[:, :, i]
+            k_t = key[:, :, i]
+            v_t = value[:, :, i]
+            g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+            beta_t = beta[:, :, i].unsqueeze(-1)
+
+            last_recurrent_state = last_recurrent_state * g_t
+            kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+            delta = (v_t - kv_mem) * beta_t
+            last_recurrent_state = last_recurrent_state + k_t.unsqueeze(
+                -1
+            ) * delta.unsqueeze(-2)
+            core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
+                dim=-2
+            )
+
+        with torch.no_grad():
+            self.recurrent_state[:batch_size].copy_(
+                last_recurrent_state.to(self.recurrent_state.dtype)
+            )
+
+        return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        del freqs_cos
+        del freqs_sin
+        input_pos = kwargs.get("input_pos")
+        batch_size, seq_len, _ = x.shape
+        assert (
+            batch_size <= self.max_batch_size
+        ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"
+
+        self._maybe_reset_state(input_pos, batch_size)
+
+        mixed_qkv = self.in_proj_qkv(x)
+        z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
+        b = self.in_proj_b(x)
+        a = self.in_proj_a(x)
+
+        mixed_qkv = self._apply_causal_conv(mixed_qkv)
+        query, key, value = torch.split(
+            mixed_qkv,
+            [self.key_dim, self.key_dim, self.value_dim],
+            dim=-1,
+        )
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
+        if self.head_repeat > 1:
+            query = query.repeat_interleave(self.head_repeat, dim=2)
+            key = key.repeat_interleave(self.head_repeat, dim=2)
+
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)
+
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        return self.out_proj(core_attn_out), None
+
+
 @register_attention("skip")
 class AttentionSkip(Attention):
     def __init__(self, *args, **kwargs):
```

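The core of `AttentionGatedDeltaNet` is the token-by-token recurrence in `_recurrent_gated_delta_rule`: each head keeps a `head_k_dim x head_v_dim` fast-weight matrix `S` that is decayed by `exp(g_t)`, nudged toward the new value by a `beta`-scaled delta-rule write, and read with the query. A simplified single-head restatement (dropping the l2 normalization, query scaling, batching, and persistent state buffer of the real code):

```python
import torch


def gated_delta_rule_1head(q, k, v, g, beta):
    """q, k: (T, d_k); v: (T, d_v); g, beta: (T,). Returns (T, d_v).

    g holds log-decays (so g.exp() is in (0, 1]); beta holds write
    strengths in (0, 1), as produced by the projections in forward().
    """
    d_k, d_v = k.shape[-1], v.shape[-1]
    S = torch.zeros(d_k, d_v)  # fast-weight memory
    out = torch.empty_like(v)
    for t in range(q.shape[0]):
        S = g[t].exp() * S  # forget: decay the old memory
        pred = S.t() @ k[t]  # recall what S currently stores for k_t
        S = S + torch.outer(k[t], beta[t] * (v[t] - pred))  # delta-rule write
        out[t] = S.t() @ q[t]  # read with the query
    return out
```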
examples/models/llama/export_llama_lib.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -107,6 +107,9 @@
     "qwen3_0_6b",
     "qwen3_1_7b",
     "qwen3_4b",
+    "qwen3_5_0_8b",
+    "qwen3_5_2b",
+    "qwen3_5_4b",
     "phi_4_mini",
     "smollm2",
     "lfm2_350m",  # hybrid
@@ -124,6 +127,9 @@
     "qwen3_0_6b": "Qwen/Qwen3-0.6B",
     "qwen3_1_7b": "Qwen/Qwen3-1.7B",
     "qwen3_4b": "Qwen/Qwen3-4B",
+    "qwen3_5_0_8b": "Qwen/Qwen3.5-0.8B",
+    "qwen3_5_2b": "Qwen/Qwen3.5-2B",
+    "qwen3_5_4b": "Qwen/Qwen3.5-4B",
     "lfm2_350m": "LiquidAI/LFM2-350M",
     "lfm2_700m": "LiquidAI/LFM2-700M",
     "lfm2_1_2b": "LiquidAI/LFM2-1.2B",
@@ -622,7 +628,7 @@ def canonical_path(path: Union[str, Path], *, dir: bool = False) -> str:
     return return_val


-def export_llama(
+def export_llama(  # noqa: C901
     export_options: Union[argparse.Namespace, LlmConfig, DictConfig],
 ) -> str:
     if isinstance(export_options, argparse.Namespace):
@@ -645,6 +651,8 @@ def export_llama(
         repo_id = HUGGING_FACE_REPO_IDS[model_name]
         if model_name.startswith("qwen2_5"):
            from executorch.examples.models.qwen2_5 import convert_weights
+        elif model_name.startswith("qwen3_5"):
+            from executorch.examples.models.qwen3_5 import convert_weights
         elif model_name.startswith("qwen3"):
             from executorch.examples.models.qwen3 import convert_weights
         elif model_name == "phi_4_mini":
```