Skip to content

Commit 5863ffa

Browse files
committed
Update on "[ET Device Support] CUDA-native Qwen 3.5 MoE inference with device tensor pipeline"
Integrate the ET device tensor pipeline into the Qwen 3.5 MoE model to eliminate unnecessary H2D/D2H copies during inference. - Export: Multi-method export (`forward` + `sample`) with device memory planning enabled and method-level H2D/D2H skipping. - Runner: Custom CUDA-native inference loop that keeps logits on GPU between forward and sample, reuses CUDA tensors across iterations, and only copies the 8-byte token ID back to CPU for EOS checking. Differential Revision: [D100133933](https://our.internmc.facebook.com/intern/diff/D100133933/) [ghstack-poisoned]
1 parent 90b8a6e commit 5863ffa

2 files changed

Lines changed: 11 additions & 3 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
 
 import torch
 import torch.nn as nn
-from model import FusedMoEExperts, Qwen35MoE, Qwen35MoEConfig
+
+from executorch.examples.models.qwen3_5_moe.model import (
+    FusedMoEExperts,
+    Qwen35MoE,
+    Qwen35MoEConfig,
+)
 
 
 # ---------------------------------------------------------------------------
@@ -56,7 +61,9 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096):
     Returns:
         (model, config) ready for export.
     """
-    from quantize_and_save import load_quantized_state_dict
+    from executorch.examples.models.qwen3_5_moe.quantize_and_save import (
+        load_quantized_state_dict,
+    )
 
     config_path = os.path.join(prequantized_dir, "config.json")
     safetensors_path = os.path.join(prequantized_dir, "model.safetensors")
@@ -373,6 +380,7 @@ def _apply_turboquant(model, config):
 def export_and_lower(model, config, args):
     """Export model to .pte via torch.export + CUDA backend."""
     import torch._inductor.config as inductor_config
+
     from executorch.backends.cuda.cuda_backend import CudaBackend
     from executorch.backends.cuda.cuda_partitioner import CudaPartitioner
     from executorch.exir import (

examples/models/qwen3_5_moe/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
+from torch.nn import functional as F
 
 
 # ---------------------------------------------------------------------------

0 commit comments

Comments (0)