Skip to content

Commit a9bf006

Browse files
default run all tests
1 parent bac060e commit a9bf006

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

tests/test_ParoQuant.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,11 @@
4646

4747
def parse_args() -> argparse.Namespace:
4848
parser = argparse.ArgumentParser(description="ParoQuant repro helper")
49-
parser.add_argument("--mode", choices=("jit", "quantize"), default="jit")
49+
parser.add_argument("--mode", choices=("all", "jit", "quantize"), default="all")
5050
parser.add_argument("--rebuild", action="store_true", help="clear ParoQuant JIT cache before probing")
5151
parser.add_argument(
5252
"--model",
53-
default="/monster/data/model/Qwen3.5-27B", # "/monster/data/model/Qwen3-0.6B-Base",
53+
default="/monster/data/model/Qwen3-0.6B-Base",
5454
help="local model path for quantize mode",
5555
)
5656
parser.add_argument(
@@ -66,7 +66,7 @@ def parse_args() -> argparse.Namespace:
6666
parser.add_argument("--opt-rotation-epochs", type=int, default=1)
6767
parser.add_argument("--opt-finetune-epochs", type=int, default=1)
6868
parser.add_argument("--opt-train-samples", type=int, default=8)
69-
parser.add_argument("--opt-validation-samples", type=int, default=0)
69+
parser.add_argument("--opt-validation-samples", type=int, default=1)
7070
parser.add_argument("--opt-batch-size", type=int, default=4)
7171
parser.add_argument("--dtype", choices=("auto", "bfloat16", "float16"), default="bfloat16")
7272
return parser.parse_args()
@@ -90,7 +90,7 @@ def run_jit_repro(*, rebuild: bool) -> int:
9090
print("CUDA is not available, skip ParoQuant repro.")
9191
return 2
9292

93-
print("\n== Rebuild ParoQuant rotation extension ==")
93+
print("\n== JIT Repro ==")
9494
if rebuild:
9595
clear_paroquant_rotation_extension_cache()
9696
started = time.perf_counter()
@@ -102,7 +102,7 @@ def run_jit_repro(*, rebuild: bool) -> int:
102102
if not ok:
103103
return 1
104104

105-
print("\n== Run one fused rotation call ==")
105+
print("\n== Fused Rotation Probe ==")
106106
device = torch.device("cuda:0")
107107
x = torch.randn(32, 128, device=device, dtype=torch.bfloat16)
108108
pairs, theta, scales = build_identity_rotation_buffers(
@@ -204,6 +204,15 @@ def main() -> int:
204204
args = parse_args()
205205
if args.mode == "jit":
206206
return run_jit_repro(rebuild=args.rebuild)
207+
if args.mode == "quantize":
208+
return run_quantize_repro(args)
209+
210+
jit_result = run_jit_repro(rebuild=args.rebuild)
211+
if jit_result != 0:
212+
print("\nJIT repro failed; skip quantize repro.")
213+
return jit_result
214+
215+
print("\n== Proceed To Quantize Repro ==")
207216
return run_quantize_repro(args)
208217

209218

0 commit comments

Comments
 (0)