Skip to content

Commit fe403f5

Browse files
add test file
1 parent 3a2becf commit fe403f5

1 file changed

Lines changed: 211 additions & 0 deletions

File tree

tests/tests_ParoQuant.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
#!/usr/bin/env python3
2+
# -*- coding: utf-8 -*-
3+
4+
from __future__ import annotations
5+
6+
import argparse
7+
import os
8+
import sys
9+
import time
10+
from pathlib import Path
11+
12+
import torch
13+
from torch.utils.cpp_extension import CUDA_HOME
14+
15+
from gptqmodel import GPTQModel
16+
from gptqmodel import extension
17+
from gptqmodel.quantization.config import FORMAT, METHOD, ParoConfig
18+
from gptqmodel.utils.paroquant import (
19+
apply_paroquant_rotation,
20+
build_identity_rotation_buffers,
21+
clear_paroquant_rotation_extension_cache,
22+
)
23+
24+
CALIBRATION_TEXTS = [
25+
"Summarize the role of CUDA kernel compilation in PyTorch custom operators.",
26+
"Explain why a quantization fallback path can make model compression much slower.",
27+
"Qwen models are decoder-only transformers optimized for generation workloads.",
28+
"ParoQuant applies pairwise rotations before quantization to reduce approximation error.",
29+
"A small calibration set is enough for reproducing failures even when accuracy is not the goal.",
30+
"The purpose of this run is to reproduce the issue path, not to measure final model quality.",
31+
"When a JIT extension fails instantly, the root cause is often toolchain discovery rather than CUDA execution.",
32+
"Quantization logs should clearly distinguish compilation failures from runtime numerical problems.",
33+
]
34+
35+
36+
# python tests_ParoQuant.py --mode quantize \
37+
# --model /monster/data/model/Qwen3.5-27B \
38+
# --output-dir /tmp/paroquant_qwen3_0_6b_test \
39+
# --calibration-samples 8 \
40+
# --batch-size 1 \
41+
# --opt-rotation-epochs 1 \
42+
# --opt-finetune-epochs 1 \
43+
# --opt-train-samples 8 \
44+
# --opt-validation-samples 1 \
45+
# --opt-batch-size 4
46+
47+
def parse_args() -> argparse.Namespace:
48+
parser = argparse.ArgumentParser(description="ParoQuant repro helper")
49+
parser.add_argument("--mode", choices=("jit", "quantize"), default="jit")
50+
parser.add_argument("--rebuild", action="store_true", help="clear ParoQuant JIT cache before probing")
51+
parser.add_argument(
52+
"--model",
53+
default="/monster/data/model/Qwen3.5-27B", # "/monster/data/model/Qwen3-0.6B-Base",
54+
help="local model path for quantize mode",
55+
)
56+
parser.add_argument(
57+
"--output-dir",
58+
default="/tmp/paroquant_qwen3_0_6b_test",
59+
help="save path for quantized model",
60+
)
61+
parser.add_argument("--bits", type=int, default=4)
62+
parser.add_argument("--group-size", type=int, default=128)
63+
parser.add_argument("--batch-size", type=int, default=1)
64+
parser.add_argument("--calibration-samples", type=int, default=8)
65+
parser.add_argument("--calibration-concat-size", type=int, default=0)
66+
parser.add_argument("--opt-rotation-epochs", type=int, default=1)
67+
parser.add_argument("--opt-finetune-epochs", type=int, default=1)
68+
parser.add_argument("--opt-train-samples", type=int, default=8)
69+
parser.add_argument("--opt-validation-samples", type=int, default=0)
70+
parser.add_argument("--opt-batch-size", type=int, default=4)
71+
parser.add_argument("--dtype", choices=("auto", "bfloat16", "float16"), default="bfloat16")
72+
return parser.parse_args()
73+
74+
75+
def print_environment() -> None:
76+
print("== Environment ==")
77+
print(f"python={sys.version}")
78+
print(f"torch={torch.__version__}")
79+
print(f"torch_cuda={torch.version.cuda}")
80+
print(f"cuda_home={CUDA_HOME}")
81+
print(f"cuda_available={torch.cuda.is_available()}")
82+
print(f"device_count={torch.cuda.device_count()}")
83+
for idx in range(torch.cuda.device_count()):
84+
print(f"device[{idx}] capability={torch.cuda.get_device_capability(idx)}")
85+
86+
87+
def run_jit_repro(*, rebuild: bool) -> int:
88+
print_environment()
89+
if not torch.cuda.is_available():
90+
print("CUDA is not available, skip ParoQuant repro.")
91+
return 2
92+
93+
print("\n== Rebuild ParoQuant rotation extension ==")
94+
if rebuild:
95+
clear_paroquant_rotation_extension_cache()
96+
started = time.perf_counter()
97+
ok = extension.is_available("paroquant", use_cache=not rebuild)
98+
elapsed = time.perf_counter() - started
99+
print(f"is_available={ok}")
100+
print(f"elapsed={elapsed:.3f}s")
101+
print(f"error={extension.error('paroquant')}")
102+
if not ok:
103+
return 1
104+
105+
print("\n== Run one fused rotation call ==")
106+
device = torch.device("cuda:0")
107+
x = torch.randn(32, 128, device=device, dtype=torch.bfloat16)
108+
pairs, theta, scales = build_identity_rotation_buffers(
109+
in_features=128,
110+
group_size=128,
111+
krot=1,
112+
device=device,
113+
dtype=torch.bfloat16,
114+
)
115+
y = apply_paroquant_rotation(x, pairs, theta, scales, group_size=128)
116+
print(f"output_shape={tuple(y.shape)}")
117+
print(f"output_dtype={y.dtype}")
118+
print(f"max_abs_diff={(y - x).abs().max().item():.6f}")
119+
return 0
120+
121+
122+
def _resolve_dtype(name: str):
123+
if name == "auto":
124+
return "auto"
125+
if name == "bfloat16":
126+
return torch.bfloat16
127+
if name == "float16":
128+
return torch.float16
129+
raise ValueError(f"unsupported dtype: {name}")
130+
131+
132+
def run_quantize_repro(args: argparse.Namespace) -> int:
133+
print_environment()
134+
model_path = Path(args.model)
135+
if not model_path.exists():
136+
print(f"model path does not exist: {model_path}")
137+
return 2
138+
if not torch.cuda.is_available():
139+
print("CUDA is not available, skip quantize repro.")
140+
return 2
141+
142+
calibration_dataset = CALIBRATION_TEXTS[: args.calibration_samples]
143+
print("\n== Quantize Setup ==")
144+
print(f"model={model_path}")
145+
print(f"output_dir={args.output_dir}")
146+
print(f"calibration_samples={len(calibration_dataset)}")
147+
print(f"batch_size={args.batch_size}")
148+
149+
qcfg = ParoConfig(
150+
bits=args.bits,
151+
group_size=args.group_size,
152+
method=METHOD.PARO,
153+
format=FORMAT.PAROQUANT,
154+
opt_scope="module",
155+
opt_rotation_epochs=args.opt_rotation_epochs,
156+
opt_finetune_epochs=args.opt_finetune_epochs,
157+
opt_train_samples=args.opt_train_samples,
158+
opt_validation_samples=args.opt_validation_samples,
159+
opt_batch_size=args.opt_batch_size,
160+
opt_pair_impl="fast",
161+
opt_quantizer_impl="reference",
162+
opt_stage_impl="fast",
163+
offload_to_disk=True,
164+
)
165+
166+
if args.rebuild:
167+
clear_paroquant_rotation_extension_cache()
168+
169+
print("\n== Load Model ==")
170+
load_started = time.perf_counter()
171+
model = GPTQModel.load(
172+
str(model_path),
173+
quantize_config=qcfg,
174+
trust_remote_code=False,
175+
dtype=_resolve_dtype(args.dtype),
176+
)
177+
print(f"load_elapsed={time.perf_counter() - load_started:.3f}s")
178+
179+
print("\n== Quantize ==")
180+
quant_started = time.perf_counter()
181+
quant_logs = model.quantize(
182+
calibration_dataset,
183+
batch_size=args.batch_size,
184+
calibration_concat_size=args.calibration_concat_size,
185+
calibration_sort="desc",
186+
)
187+
quant_elapsed = time.perf_counter() - quant_started
188+
print(f"quant_elapsed={quant_elapsed:.3f}s")
189+
print(f"quant_log_keys={sorted(quant_logs.keys()) if isinstance(quant_logs, dict) else type(quant_logs).__name__}")
190+
print(f"paroquant_extension_error={extension.error('paroquant')}")
191+
192+
print("\n== Save ==")
193+
output_dir = Path(args.output_dir)
194+
output_dir.mkdir(parents=True, exist_ok=True)
195+
save_started = time.perf_counter()
196+
model.save(str(output_dir))
197+
print(f"save_elapsed={time.perf_counter() - save_started:.3f}s")
198+
print(f"saved_to={output_dir}")
199+
return 0
200+
201+
202+
def main() -> int:
203+
os.environ.setdefault("GPTQMODEL_EXT_VERBOSE", "1")
204+
args = parse_args()
205+
if args.mode == "jit":
206+
return run_jit_repro(rebuild=args.rebuild)
207+
return run_quantize_repro(args)
208+
209+
210+
if __name__ == "__main__":
211+
raise SystemExit(main())

0 commit comments

Comments
 (0)