import torch.nn.functional as F
from executorch.examples.models.llama.lora import LoRALinear
from executorch.examples.models.llama.model_args import ModelArgs
-from executorch.examples.models.llama.norm import RMSNorm
+from executorch.examples.models.llama.norm import RMSNorm, RMSNormGated
from executorch.examples.models.llama.rope import Rope


@@ -347,27 +347,35 @@ def __init__(
        self.attention_qkv_bias = args.attention_qkv_bias
        self.use_qk_norm = args.use_qk_norm
        self.qk_norm_before_rope = args.qk_norm_before_rope
+        self.use_q_gate = args.use_q_gate
        self.enable_dynamic_shape = args.enable_dynamic_shape
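+        # With q-gating enabled, wq emits [q | gate] concatenated along the
+        # last dim, doubling its output width.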
+        q_out_dim = self.n_heads * self.head_dim * (2 if self.use_q_gate else 1)

        if self.use_qk_norm:
            q_norm_dim = self.head_dim
            k_norm_dim = self.head_dim
-            self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
-            self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
+            self.q_norm_fn = RMSNorm(
+                q_norm_dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )
+            self.k_norm_fn = RMSNorm(
+                k_norm_dim,
+                eps=args.norm_eps,
+                add_unit_offset=args.rms_norm_add_unit_offset,
+            )

        self.wq = (
            LoRALinear(
                in_dim=args.dim,
-                out_dim=args.n_heads * args.head_dim,
+                out_dim=q_out_dim,
                rank=args.r,
                alpha=args.lora_alpha,
                dropout=0.0,
                use_bias=args.attention_qkv_bias,
            )
            if args.target_modules is not None and "q_proj" in args.target_modules
-            else nn.Linear(
-                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
-            )
+            else nn.Linear(self.dim, q_out_dim, bias=self.attention_qkv_bias)
        )
        self.wk = (
            LoRALinear(
@@ -452,10 +460,17 @@ def forward(
        input_pos = kwargs.get("input_pos")
        bsz, seqlen, _ = x.shape

-        # QKV
-        q, k, v = self.wq(x), self.wk(x), self.wv(x)
-        # We need view_copy elimination
-        q = q.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        if self.use_q_gate:
+            q_and_gate = self.wq(x).view(
+                bsz, seqlen, self.n_local_heads, self.head_dim * 2
+            )
+            q, gate = torch.chunk(q_and_gate, 2, dim=-1)
+            gate = gate.reshape(bsz, seqlen, -1)
+        else:
+            q = self.wq(x).view(bsz, seqlen, self.n_local_heads, self.head_dim)
+            gate = None
+
+        k, v = self.wk(x), self.wv(x)
        k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

@@ -492,6 +507,8 @@ def forward(
                input_pos[0].item(), seqlen
            )
            output = self.SDPA(input_pos, q, k, v, bsz, seqlen, attn_mask)
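+            # Optional output gating: sigmoid(gate) rescales the attention
+            # output elementwise before the final wo projection.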
+            if gate is not None:
+                output = output * torch.sigmoid(gate)
            return self.wo(output), None

        # grouped multiquery attention: expand out keys and values
@@ -505,12 +522,234 @@ def forward(
        output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        output = output.transpose(1, 2).reshape(bsz, seqlen, -1)
+        if gate is not None:
+            output = output * torch.sigmoid(gate)

        output = self.wo(output)

        return output, None


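+# L2-normalize x along `dim`; the gated delta rule below normalizes q and k
+# this way instead of applying an RMSNorm.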
+def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:
+    inv_norm = torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps)
+    return x * inv_norm
+
+
+@register_attention("gated_deltanet")
+class AttentionGatedDeltaNet(Attention):
+    """Qwen3.5 linear-attention (Gated DeltaNet) block with internal state."""
+
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+        **_kwargs: Any,
+    ):
+        super().__init__()
+        del rope  # DeltaNet layers do not use RoPE.
+
+        self.hidden_size = args.dim
+        self.max_batch_size = args.max_batch_size
+        self.layer_id = layer_id
+
+        assert args.linear_num_key_heads is not None
+        assert args.linear_num_value_heads is not None
+        assert args.linear_key_head_dim is not None
+        assert args.linear_value_head_dim is not None
+
+        self.num_k_heads = args.linear_num_key_heads
+        self.num_v_heads = args.linear_num_value_heads
+        self.head_k_dim = args.linear_key_head_dim
+        self.head_v_dim = args.linear_value_head_dim
+        self.key_dim = self.head_k_dim * self.num_k_heads
+        self.value_dim = self.head_v_dim * self.num_v_heads
+        self.conv_kernel_size = args.linear_conv_kernel_dim
+
+        assert (
+            self.num_v_heads % self.num_k_heads == 0
+        ), "linear_num_value_heads must be divisible by linear_num_key_heads."
+        self.head_repeat = self.num_v_heads // self.num_k_heads
+
+        self.conv_dim = self.key_dim * 2 + self.value_dim
+        self.in_proj_qkv = nn.Linear(self.hidden_size, self.conv_dim, bias=False)
+        self.in_proj_z = nn.Linear(self.hidden_size, self.value_dim, bias=False)
+        self.in_proj_b = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+        self.in_proj_a = nn.Linear(self.hidden_size, self.num_v_heads, bias=False)
+
+        self.conv1d = nn.Conv1d(
+            in_channels=self.conv_dim,
+            out_channels=self.conv_dim,
+            kernel_size=self.conv_kernel_size,
+            groups=self.conv_dim,
+            bias=False,
+            padding=0,
+        )
+
+        self.dt_bias = nn.Parameter(torch.ones(self.num_v_heads))
+        A = torch.empty(self.num_v_heads).uniform_(0, 16)
+        self.A_log = nn.Parameter(torch.log(A))
+        self.norm = RMSNormGated(self.head_v_dim, eps=args.norm_eps)
+        self.out_proj = nn.Linear(self.value_dim, self.hidden_size, bias=False)
+
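+        # Persistent per-sequence caches carried across calls: a rolling window
+        # of the last conv_kernel_size inputs for the depthwise conv, and the
+        # per-head (head_k_dim x head_v_dim) recurrent state for the delta rule.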
+        self.register_buffer(
+            "conv_state",
+            torch.zeros(
+                self.max_batch_size,
+                self.conv_dim,
+                self.conv_kernel_size,
+                dtype=torch.float32,
+                device="cpu",
+            ),
+        )
+        self.register_buffer(
+            "recurrent_state",
+            torch.zeros(
+                self.max_batch_size,
+                self.num_v_heads,
+                self.head_k_dim,
+                self.head_v_dim,
+                dtype=torch.float32,
+                device="cpu",
+            ),
+        )
+
+    def _maybe_reset_state(
+        self, input_pos: Optional[torch.Tensor], batch_size: int
+    ) -> None:
+        if input_pos is None:
+            self.conv_state[:batch_size].zero_()
+            self.recurrent_state[:batch_size].zero_()
+            return
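+        # input_pos[0] == 0 marks the start of a new sequence: clear both
+        # caches with a 0/1 multiplicative mask rather than branching on a
+        # tensor value.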
+        reset = (input_pos[0] == 0).to(self.conv_state.dtype)
+        keep = 1.0 - reset
+        self.conv_state[:batch_size].mul_(keep)
+        self.recurrent_state[:batch_size].mul_(keep)
+
+    def _apply_causal_conv(self, mixed_qkv: torch.Tensor) -> torch.Tensor:
+        # mixed_qkv: (batch, seq_len, conv_dim)
+        batch_size, seq_len, _ = mixed_qkv.shape
+        mixed_qkv = mixed_qkv.transpose(1, 2)
+        state_len = self.conv_state.shape[-1]
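+        # Prepend the cached tail of the previous window so the depthwise conv
+        # stays causal across calls, then save the new tail for the next call.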
+        hidden_states_new = torch.cat([self.conv_state[:batch_size], mixed_qkv], dim=-1)
+        new_conv_state = hidden_states_new[:, :, -state_len:]
+        with torch.no_grad():
+            self.conv_state[:batch_size].copy_(new_conv_state.to(self.conv_state.dtype))
+        out = F.conv1d(
+            hidden_states_new,
+            self.conv1d.weight,
+            self.conv1d.bias,
+            padding=0,
+            groups=self.conv_dim,
+        )
+        out = F.silu(out[:, :, -seq_len:]).to(mixed_qkv.dtype)
+        return out.transpose(1, 2).contiguous()
+
+    def _recurrent_gated_delta_rule(
+        self,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+    ) -> torch.Tensor:
+        # query/key/value: (batch, seq_len, num_heads, head_dim)
+        # g/beta: (batch, seq_len, num_heads)
+        initial_dtype = query.dtype
+        query = _l2norm(query, dim=-1, eps=1e-6)
+        key = _l2norm(key, dim=-1, eps=1e-6)
+        query, key, value, beta, g = [
+            x.transpose(1, 2).contiguous().to(torch.float32)
+            for x in (query, key, value, beta, g)
+        ]
+
+        batch_size, num_heads, sequence_length, k_head_dim = key.shape
+        v_head_dim = value.shape[-1]
+        scale = 1.0 / (query.shape[-1] ** 0.5)
+        query = query * scale
+
+        core_attn_out = torch.zeros(
+            batch_size,
+            num_heads,
+            sequence_length,
+            v_head_dim,
+            device=value.device,
+            dtype=value.dtype,
+        )
+        last_recurrent_state = self.recurrent_state[:batch_size].to(value.dtype)
+
+        for i in range(sequence_length):
+            q_t = query[:, :, i]
+            k_t = key[:, :, i]
+            v_t = value[:, :, i]
+            g_t = g[:, :, i].exp().unsqueeze(-1).unsqueeze(-1)
+            beta_t = beta[:, :, i].unsqueeze(-1)
+
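+            # Delta-rule step: decay the state, read the memory's current
+            # prediction for k_t, then write back a beta-scaled correction
+            # toward v_t.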
+            last_recurrent_state = last_recurrent_state * g_t
+            kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2)
+            delta = (v_t - kv_mem) * beta_t
+            last_recurrent_state = last_recurrent_state + k_t.unsqueeze(
+                -1
+            ) * delta.unsqueeze(-2)
+            core_attn_out[:, :, i] = (last_recurrent_state * q_t.unsqueeze(-1)).sum(
+                dim=-2
+            )
+
+        with torch.no_grad():
+            self.recurrent_state[:batch_size].copy_(
+                last_recurrent_state.to(self.recurrent_state.dtype)
+            )
+
+        return core_attn_out.transpose(1, 2).contiguous().to(initial_dtype)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        freqs_cos: torch.Tensor,
+        freqs_sin: torch.Tensor,
+        **kwargs: ForwardOptions,
+    ) -> Tuple[torch.Tensor, Optional[Any]]:
+        del freqs_cos
+        del freqs_sin
+        input_pos = kwargs.get("input_pos")
+        batch_size, seq_len, _ = x.shape
+        assert (
+            batch_size <= self.max_batch_size
+        ), f"batch_size ({batch_size}) exceeds max_batch_size ({self.max_batch_size})"
+
+        self._maybe_reset_state(input_pos, batch_size)
+
+        mixed_qkv = self.in_proj_qkv(x)
+        z = self.in_proj_z(x).reshape(batch_size, seq_len, -1, self.head_v_dim)
+        b = self.in_proj_b(x)
+        a = self.in_proj_a(x)
+
+        mixed_qkv = self._apply_causal_conv(mixed_qkv)
+        query, key, value = torch.split(
+            mixed_qkv,
+            [self.key_dim, self.key_dim, self.value_dim],
+            dim=-1,
+        )
+        query = query.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        key = key.reshape(batch_size, seq_len, -1, self.head_k_dim)
+        value = value.reshape(batch_size, seq_len, -1, self.head_v_dim)
+
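+        # Fewer k (and q) heads than v heads: repeat them GQA-style so every
+        # value head gets its own query/key.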
+        if self.head_repeat > 1:
+            query = query.repeat_interleave(self.head_repeat, dim=2)
+            key = key.repeat_interleave(self.head_repeat, dim=2)
+
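+        # beta in (0, 1) scales each delta update; g = -exp(A_log) *
+        # softplus(a + dt_bias) is <= 0, so exp(g) in (0, 1] is the per-head
+        # forgetting gate.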
+        beta = b.sigmoid()
+        g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+        core_attn_out = self._recurrent_gated_delta_rule(query, key, value, g, beta)
+
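+        # Gated RMSNorm per value head: z (from in_proj_z) is passed as the
+        # gating input before the final output projection.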
+        core_attn_out = core_attn_out.reshape(-1, self.head_v_dim)
+        z = z.reshape(-1, self.head_v_dim)
+        core_attn_out = self.norm(core_attn_out, z)
+        core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1)
+
+        return self.out_proj(core_attn_out), None
+
+
@register_attention("skip")
class AttentionSkip(Attention):
    def __init__(self, *args, **kwargs):