4646
4747def parse_args () -> argparse .Namespace :
4848 parser = argparse .ArgumentParser (description = "ParoQuant repro helper" )
49- parser .add_argument ("--mode" , choices = ("jit" , "quantize" ), default = "jit " )
49+ parser .add_argument ("--mode" , choices = ("all" , " jit" , "quantize" ), default = "all " )
5050 parser .add_argument ("--rebuild" , action = "store_true" , help = "clear ParoQuant JIT cache before probing" )
5151 parser .add_argument (
5252 "--model" ,
53- default = "/monster/data/model/Qwen3.5-27B" , # "/monster/data/model/Qwen3 -0.6B-Base",
53+ default = "/monster/data/model/Qwen3-0.6B-Base" ,
5454 help = "local model path for quantize mode" ,
5555 )
5656 parser .add_argument (
@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:
6666 parser .add_argument ("--opt-rotation-epochs" , type = int , default = 1 )
6767 parser .add_argument ("--opt-finetune-epochs" , type = int , default = 1 )
6868 parser .add_argument ("--opt-train-samples" , type = int , default = 8 )
69- parser .add_argument ("--opt-validation-samples" , type = int , default = 0 )
69+ parser .add_argument ("--opt-validation-samples" , type = int , default = 1 )
7070 parser .add_argument ("--opt-batch-size" , type = int , default = 4 )
7171 parser .add_argument ("--dtype" , choices = ("auto" , "bfloat16" , "float16" ), default = "bfloat16" )
7272 return parser .parse_args ()
@@ -90,7 +90,7 @@ def run_jit_repro(*, rebuild: bool) -> int:
9090 print ("CUDA is not available, skip ParoQuant repro." )
9191 return 2
9292
93- print ("\n == Rebuild ParoQuant rotation extension ==" )
93+ print ("\n == JIT Repro ==" )
9494 if rebuild :
9595 clear_paroquant_rotation_extension_cache ()
9696 started = time .perf_counter ()
@@ -102,7 +102,7 @@ def run_jit_repro(*, rebuild: bool) -> int:
102102 if not ok :
103103 return 1
104104
105- print ("\n == Run one fused rotation call ==" )
105+ print ("\n == Fused Rotation Probe ==" )
106106 device = torch .device ("cuda:0" )
107107 x = torch .randn (32 , 128 , device = device , dtype = torch .bfloat16 )
108108 pairs , theta , scales = build_identity_rotation_buffers (
@@ -204,6 +204,15 @@ def main() -> int:
204204 args = parse_args ()
205205 if args .mode == "jit" :
206206 return run_jit_repro (rebuild = args .rebuild )
207+ if args .mode == "quantize" :
208+ return run_quantize_repro (args )
209+
210+ jit_result = run_jit_repro (rebuild = args .rebuild )
211+ if jit_result != 0 :
212+ print ("\n JIT repro failed; skip quantize repro." )
213+ return jit_result
214+
215+ print ("\n == Proceed To Quantize Repro ==" )
207216 return run_quantize_repro (args )
208217
209218
0 commit comments