@@ -196,6 +196,12 @@ def load_and_quantize(args): # noqa: C901
196196 # CUDA: quantize experts with packed INT4 for Triton kernel
197197 if args .qlinear or args .qembedding :
198198 _quantize (model , config , args )
199+ # Unwrap torchao tensor subclasses (e.g. AffineQuantizedTensor)
200+ # into parametrized plain tensors so torch.export can handle them.
201+ # Mirrors executorch/export/stages.py SourceTransformStage.
202+ from torchao .utils import unwrap_tensor_subclass
203+
204+ unwrap_tensor_subclass (model )
199205 else :
200206 model .to (dtype = torch .bfloat16 )
201207
@@ -290,6 +296,14 @@ def load_prequantized_model(prequantized_dir, max_seq_len=4096, use_splitk_decod
290296 # requires_grad=True, which fails. Disable grad on all parameters.
291297 for p in model .parameters ():
292298 p .requires_grad_ (False )
299+
300+ # Unwrap any torchao tensor subclasses (e.g. AffineQuantizedTensor) into
301+ # parametrized plain tensors so torch.export can handle them. Mirrors the
302+ # canonical pattern in executorch/export/stages.py SourceTransformStage.
303+ from torchao .utils import unwrap_tensor_subclass
304+
305+ unwrap_tensor_subclass (model )
306+
293307 model .eval ()
294308
295309 print (
@@ -966,6 +980,9 @@ def main(): # noqa: C901
966980 # Register FLA Triton kernel (CUDA only)
967981 import executorch .backends .cuda .triton .kernels # noqa: F401
968982
983+ if torch .cuda .is_available ():
984+ torch .cuda .reset_peak_memory_stats ()
985+
969986 if args .backend == "mlx" :
970987 if args .prequantized :
971988 parser .error ("--prequantized is not supported with --backend mlx" )
@@ -988,6 +1005,14 @@ def main(): # noqa: C901
9881005
9891006 export_and_lower (model , config , args )
9901007
1008+ if args .backend == "cuda" and torch .cuda .is_available ():
1009+ peak_alloc_gb = torch .cuda .max_memory_allocated () / 1e9
1010+ peak_reserved_gb = torch .cuda .max_memory_reserved () / 1e9
1011+ print (
1012+ f"[CUDA peak memory] allocated={ peak_alloc_gb :.2f} GB, "
1013+ f"reserved={ peak_reserved_gb :.2f} GB"
1014+ )
1015+
9911016
9921017if __name__ == "__main__" :
9931018 main ()
0 commit comments