2121
2222import torch
2323import torch .nn as nn
24-
2524from executorch .examples .models .qwen3_5_moe .sampler import sample
2625from torch .nn import functional as F
2726
@@ -186,7 +185,6 @@ def _apply_rotary(x, cos, sin):
186185
187186
188187class KVCache (nn .Module ):
189-
190188 def __init__ (self , n_kv_heads , head_dim , max_seq_len ):
191189 super ().__init__ ()
192190 self .register_buffer (
@@ -207,7 +205,6 @@ def update(self, input_pos, k_val, v_val):
207205
208206
209207class FullAttention (nn .Module ):
210-
211208 def __init__ (self , config ):
212209 super ().__init__ ()
213210 self .n_heads = config .num_attention_heads
@@ -318,7 +315,6 @@ def forward(self, x, input_pos):
318315
319316
320317class GatedDeltaNet (nn .Module ):
321-
322318 def __init__ (self , config ):
323319 super ().__init__ ()
324320 self .num_k_heads = config .linear_num_key_heads
@@ -540,7 +536,6 @@ def forward(self, x):
540536
541537
542538class SparseMoE (nn .Module ):
543-
544539 def __init__ (self , config ):
545540 super ().__init__ ()
546541 self .top_k = config .num_experts_per_tok
@@ -574,7 +569,6 @@ def forward(self, x):
574569
575570
576571class Block (nn .Module ):
577-
578572 def __init__ (self , config , layer_idx ):
579573 super ().__init__ ()
580574 self .layer_type = config .layer_types [layer_idx ]
@@ -599,7 +593,6 @@ def forward(self, x, input_pos):
599593
600594
601595class Qwen35MoE (nn .Module ):
602-
603596 def __init__ (self , config ):
604597 super ().__init__ ()
605598 self .config = config
@@ -625,7 +618,7 @@ def forward(
625618 # position. Otherwise apply the prefill optimization and only
626619 # materialize ``[B, V]`` for the last token.
627620 if temperature is None :
628- return self .lm_head (x ). float () # [B, T, V] float32
621+ return self .lm_head (x ) # [B, T, V] in model dtype
629622 logits = self .lm_head (x [:, - 1 , :]).float () # [B, V] float32
630623 # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
631624 # equivalent to drawing from softmax(logits/T) but stays entirely
0 commit comments