
Commit 15589f3

make moe output dtype consistent on non-cuda backends
1 parent: 3a62fac

1 file changed: 1 addition & 8 deletions

examples/models/qwen3_5_moe/model.py
@@ -21,7 +21,6 @@
 
 import torch
 import torch.nn as nn
-
 from executorch.examples.models.qwen3_5_moe.sampler import sample
 from torch.nn import functional as F
 
@@ -186,7 +185,6 @@ def _apply_rotary(x, cos, sin):
 
 
 class KVCache(nn.Module):
-
     def __init__(self, n_kv_heads, head_dim, max_seq_len):
         super().__init__()
         self.register_buffer(
@@ -207,7 +205,6 @@ def update(self, input_pos, k_val, v_val):
 
 
 class FullAttention(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.n_heads = config.num_attention_heads
@@ -318,7 +315,6 @@ def forward(self, x, input_pos):
 
 
 class GatedDeltaNet(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.num_k_heads = config.linear_num_key_heads
@@ -540,7 +536,6 @@ def forward(self, x):
 
 
 class SparseMoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.top_k = config.num_experts_per_tok
@@ -574,7 +569,6 @@ def forward(self, x):
 
 
 class Block(nn.Module):
-
     def __init__(self, config, layer_idx):
         super().__init__()
         self.layer_type = config.layer_types[layer_idx]
@@ -599,7 +593,6 @@ def forward(self, x, input_pos):
 
 
 class Qwen35MoE(nn.Module):
-
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -625,7 +618,7 @@ def forward(
         # position. Otherwise apply the prefill optimization and only
         # materialize ``[B, V]`` for the last token.
         if temperature is None:
-            return self.lm_head(x).float()  # [B, T, V] float32
+            return self.lm_head(x)  # [B, T, V] in model dtype
         logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
         # GPU-side Gumbel-max sampling: argmax(logits/T + gumbel_noise) is
         # equivalent to drawing from softmax(logits/T) but stays entirely
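
For context, the logits path the final hunk touches can be sketched outside the repository. The `TinyLM` class, its dimensions, and the inline Gumbel sampling below are illustrative stand-ins, not the repository's actual `Qwen35MoE` or `sample` implementation; only the dtype behavior of `lm_head` and the Gumbel-max identity mirror the diff.

    # Minimal sketch (assumed names, not the repo's API) of why dropping
    # `.float()` keeps the logits-only output in the model's dtype.
    import torch
    import torch.nn as nn

    class TinyLM(nn.Module):
        def __init__(self, hidden=16, vocab=32, dtype=torch.bfloat16):
            super().__init__()
            self.lm_head = nn.Linear(hidden, vocab, dtype=dtype)

        def forward(self, x, temperature=None):
            if temperature is None:
                # Logits-only path: return in the model dtype. The removed
                # `.float()` forced float32 here regardless of backend.
                return self.lm_head(x)                  # [B, T, V], model dtype
            logits = self.lm_head(x[:, -1, :]).float()  # [B, V] float32
            # Gumbel-max trick from the diff comments: argmax(logits/T + g),
            # with g ~ Gumbel(0, 1), is a draw from softmax(logits/T) and
            # stays entirely on-device (no softmax/multinomial needed).
            u = torch.rand_like(logits).clamp(min=1e-20)
            gumbel = -torch.log(-torch.log(u))
            return torch.argmax(logits / temperature + gumbel, dim=-1)  # [B]

    model = TinyLM()
    x = torch.randn(2, 5, 16, dtype=torch.bfloat16)
    print(model(x).dtype)             # torch.bfloat16, matches the model dtype
    print(model(x, temperature=0.8))  # sampled token ids, shape [2]

Presumably the unconditional float32 upcast gave the MoE model a different output dtype than the surrounding model dtype on non-CUDA backends; returning `lm_head(x)` unchanged makes the output dtype consistent everywhere, per the commit title. The sampling path still upcasts to float32, but only for the single last-token slice it materializes.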
