Skip to content

Commit 1df5d56

Browse files
committed
up
1 parent 7b5dcc1 commit 1df5d56

1 file changed

Lines changed: 25 additions & 0 deletions

File tree

examples/models/qwen3_5_moe/export.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,12 @@ def load_and_quantize(args): # noqa: C901
196196
# CUDA: quantize experts with packed INT4 for Triton kernel
197197
if args.qlinear or args.qembedding:
198198
_quantize(model, config, args)
199+
# Unwrap torchao tensor subclasses (e.g. AffineQuantizedTensor)
200+
# into parametrized plain tensors so torch.export can handle them.
201+
# Mirrors executorch/export/stages.py SourceTransformStage.
202+
from torchao.utils import unwrap_tensor_subclass
203+
204+
unwrap_tensor_subclass(model)
199205
else:
200206
model.to(dtype=torch.bfloat16)
201207

@@ -290,6 +296,14 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096, use_splitk_decod
290296
# requires_grad=True, which fails. Disable grad on all parameters.
291297
for p in model.parameters():
292298
p.requires_grad_(False)
299+
300+
# Unwrap any torchao tensor subclasses (e.g. AffineQuantizedTensor) into
301+
# parametrized plain tensors so torch.export can handle them. Mirrors the
302+
# canonical pattern in executorch/export/stages.py SourceTransformStage.
303+
from torchao.utils import unwrap_tensor_subclass
304+
305+
unwrap_tensor_subclass(model)
306+
293307
model.eval()
294308

295309
print(
@@ -966,6 +980,9 @@ def main(): # noqa: C901
966980
# Register FLA Triton kernel (CUDA only)
967981
import executorch.backends.cuda.triton.kernels # noqa: F401
968982

983+
if torch.cuda.is_available():
984+
torch.cuda.reset_peak_memory_stats()
985+
969986
if args.backend == "mlx":
970987
if args.prequantized:
971988
parser.error("--prequantized is not supported with --backend mlx")
@@ -988,6 +1005,14 @@ def main(): # noqa: C901
9881005

9891006
export_and_lower(model, config, args)
9901007

1008+
if args.backend == "cuda" and torch.cuda.is_available():
1009+
peak_alloc_gb = torch.cuda.max_memory_allocated() / 1e9
1010+
peak_reserved_gb = torch.cuda.max_memory_reserved() / 1e9
1011+
print(
1012+
f"[CUDA peak memory] allocated={peak_alloc_gb:.2f} GB, "
1013+
f"reserved={peak_reserved_gb:.2f} GB"
1014+
)
1015+
9911016

9921017
if __name__ == "__main__":
9931018
main()

0 commit comments

Comments
 (0)