Skip to content

Commit dfc7d4c

Browse files
committed
address reviewer comments
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent fc14872 commit dfc7d4c

2 files changed

Lines changed: 29 additions & 41 deletions

File tree

modelopt/torch/quantization/triton/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
# fp4_kernel works on any CUDA GPU with triton
3434
from .fp4_kernel import *
3535
from .fp8_kernel import *
36-
from .nvfp4_quant import *
3736

3837
# fp4_kernel_hopper requires compute >= 8.9 (uses tl.float8e4nv)
3938
if torch.cuda.get_device_capability() >= (8, 9):

tests/gpu/torch/quantization/test_gptq.py

Lines changed: 29 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from modelopt.torch.export.unified_export_hf import _export_quantized_weight
2626
from modelopt.torch.quantization.model_calib import gptq
2727
from modelopt.torch.quantization.qtensor.nvfp4_tensor import NVFP4QTensor
28-
from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales
2928
from modelopt.torch.quantization.utils.calib_utils import update_hessian
3029
from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
3130

@@ -257,6 +256,8 @@ def _compute_h_inv(hessian, weight, percdamp=0.01):
257256

258257
def _make_nvfp4_test_data(quant_block_size, out_features, dim):
259258
"""Create weight, h_inv, and scales_2d for NVFP4 GPTQ tests."""
259+
from modelopt.torch.quantization.triton.fp4_kernel import compute_fp4_scales
260+
260261
weight = torch.randn(out_features, dim, device="cuda", dtype=torch.float32)
261262
n_blocks = dim // quant_block_size
262263
amax = weight.reshape(out_features, n_blocks, quant_block_size).abs().amax(dim=-1)
@@ -393,24 +394,12 @@ def test_fused_vs_unfused_nvfp4(quant_block_size, gptq_block_size):
393394
(128, 128, 256, 4096),
394395
]
395396

396-
_NVFP4_BENCH_IDS = [f"qbs{qbs}_gbs{gbs}_{of}x{d}" for qbs, gbs, of, d in _NVFP4_BENCH_CONFIGS]
397-
398397

399-
@requires_triton
400-
@pytest.mark.parametrize(
401-
("quant_block_size", "gptq_block_size", "out_features", "dim"),
402-
_NVFP4_BENCH_CONFIGS,
403-
ids=_NVFP4_BENCH_IDS,
404-
)
405-
def test_fused_nvfp4_benchmark(quant_block_size, gptq_block_size, out_features, dim):
406-
"""Benchmark fused Triton NVFP4 GPTQ vs unfused production loop."""
407-
torch.manual_seed(42)
398+
def bench_fused_nvfp4():
399+
"""Benchmark fused Triton NVFP4 GPTQ vs unfused production loop (informational-only).
408400
409-
weight, scales_2d, h_inv = _make_nvfp4_test_data(
410-
quant_block_size,
411-
out_features,
412-
dim,
413-
)
401+
Not collected by pytest. Run directly: ``python tests/gpu/torch/quantization/test_gptq.py``
402+
"""
414403

415404
def _bench(fn, n_warmup=2, n_iters=5):
416405
for _ in range(n_warmup):
@@ -425,30 +414,30 @@ def _bench(fn, n_warmup=2, n_iters=5):
425414
total += time.perf_counter() - t0
426415
return total / n_iters
427416

428-
def run_fused():
429-
return _run_fused_gptq_nvfp4(
430-
weight,
431-
scales_2d,
432-
h_inv,
433-
gptq_block_size,
434-
quant_block_size,
435-
)
417+
for quant_block_size, gptq_block_size, out_features, dim in _NVFP4_BENCH_CONFIGS:
418+
torch.manual_seed(42)
419+
weight, scales_2d, h_inv = _make_nvfp4_test_data(quant_block_size, out_features, dim)
436420

437-
def run_unfused():
438-
return _run_unfused_gptq_nvfp4(
439-
weight,
440-
scales_2d,
441-
h_inv,
442-
gptq_block_size,
443-
quant_block_size,
421+
def run_fused():
422+
return _run_fused_gptq_nvfp4(
423+
weight, scales_2d, h_inv, gptq_block_size, quant_block_size
424+
)
425+
426+
def run_unfused():
427+
return _run_unfused_gptq_nvfp4(
428+
weight, scales_2d, h_inv, gptq_block_size, quant_block_size
429+
)
430+
431+
t_fused = _bench(run_fused)
432+
t_unfused = _bench(run_unfused)
433+
speedup = t_unfused / t_fused if t_fused > 0 else float("inf")
434+
435+
tag = f"qbs{quant_block_size}_gbs{gptq_block_size}_{out_features}x{dim}"
436+
print(
437+
f"[{tag}] Fused: {t_fused * 1e3:8.2f} ms | "
438+
f"Unfused: {t_unfused * 1e3:8.2f} ms | Speedup: {speedup:.1f}x"
444439
)
445440

446-
t_fused = _bench(run_fused)
447-
t_unfused = _bench(run_unfused)
448-
speedup = t_unfused / t_fused if t_fused > 0 else float("inf")
449441

450-
tag = f"qbs{quant_block_size}_gbs{gptq_block_size}_{out_features}x{dim}"
451-
print(
452-
f"\n[{tag}] Fused: {t_fused * 1e3:8.2f} ms | "
453-
f"Unfused: {t_unfused * 1e3:8.2f} ms | Speedup: {speedup:.1f}x"
454-
)
442+
if __name__ == "__main__":
443+
bench_fused_nvfp4()

0 commit comments

Comments (0)