Skip to content

Commit 01e37b1

Browse files
committed
comment 2nd round
Signed-off-by: Shiyang Chen <shiychen@nvidia.com>
1 parent dfc7d4c commit 01e37b1

2 files changed

Lines changed: 5 additions & 9 deletions

File tree

modelopt/torch/quantization/triton/gptq_fused_kernel.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ def _gptq_scalar_kernel(
5353
n_scale_blocks,
5454
quant_block_size,
5555
block_start,
56-
n_cols,
5756
BLOCK_SIZE: tl.constexpr,
5857
):
5958
row = tl.program_id(0)
@@ -66,7 +65,7 @@ def _gptq_scalar_kernel(
6665
scales_base = scales_ptr + row * n_scale_blocks
6766

6867
j_range = tl.arange(0, BLOCK_SIZE)
69-
w_full = tl.load(w_base + j_range, mask=j_range < n_cols, other=0.0)
68+
w_full = tl.load(w_base + j_range)
7069

7170
for col in range(0, BLOCK_SIZE, 1):
7271
scale = tl.load(scales_base + (block_start + col) // quant_block_size)
@@ -85,7 +84,7 @@ def _gptq_scalar_kernel(
8584
tl.store(err_base + col, err_val)
8685
tl.store(qw_base + col, q_scalar)
8786

88-
remaining = (j_range > col) & (j_range < n_cols)
87+
remaining = j_range > col
8988
hinv_row = tl.load(hinv_ptr + col * BLOCK_SIZE + j_range, mask=remaining, other=0.0)
9089
w_full = w_full - err_val * hinv_row
9190

@@ -96,7 +95,6 @@ def gptq_fused_block_scalar(
9695
h_inv_cho_blk: torch.Tensor,
9796
quant_block_size: int,
9897
block_start: int,
99-
n_cols: int,
10098
) -> tuple[torch.Tensor, torch.Tensor]:
10199
"""Run scalar GPTQ (NVFP4) column loop for one block in a single Triton kernel launch.
102100
@@ -106,7 +104,6 @@ def gptq_fused_block_scalar(
106104
h_inv_cho_blk: Block of upper-Cholesky inverse Hessian ``[block_size, block_size]``.
107105
quant_block_size: Number of elements sharing one scale factor.
108106
block_start: Column offset of this block in the full weight matrix.
109-
n_cols: Number of active columns in this block.
110107
111108
Returns:
112109
``(qw_block, err_block)`` each ``[num_rows, block_size]``.
@@ -126,7 +123,6 @@ def gptq_fused_block_scalar(
126123
scales_2d.shape[1],
127124
quant_block_size,
128125
block_start,
129-
n_cols,
130126
BLOCK_SIZE=block_size,
131127
)
132128

tests/gpu/torch/quantization/test_gptq.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def test_gptq_e2e_flow(quant_cfg):
240240
# ---------------------------------------------------------------------------
241241

242242

243+
# TODO(shiychen): This should be extracted out from production code path
243244
def _compute_h_inv(hessian, weight, percdamp=0.01):
244245
"""Compute damped upper-Cholesky inverse Hessian."""
245246
h = hessian.clone()
@@ -272,6 +273,7 @@ def _make_nvfp4_test_data(quant_block_size, out_features, dim):
272273
return weight, scales_2d, h_inv
273274

274275

276+
# TODO(shiychen): This should be extracted out from production code path
275277
def _run_unfused_gptq_nvfp4(weight, scales_2d, h_inv, gptq_block_size, quant_block_size):
276278
"""Unfused NVFP4 GPTQ using the production Triton FP4 kernel per column.
277279
@@ -319,18 +321,16 @@ def _run_fused_gptq_nvfp4(weight, scales_2d, h_inv, gptq_block_size, quant_block
319321
w = weight.float().clone()
320322
for bs in range(0, dim, gptq_block_size):
321323
be = min(bs + gptq_block_size, dim)
322-
nc = be - bs
323324
qw, err = gptq_fused_block_scalar(
324325
w[:, bs:be].clone().contiguous(),
325326
scales_2d,
326327
h_inv[bs:be, bs:be].contiguous(),
327328
quant_block_size,
328329
bs,
329-
nc,
330330
)
331331
w[:, bs:be] = qw
332332
if be < dim:
333-
w[:, be:].addmm_(err[:, :nc], h_inv[bs:be, be:], alpha=-1)
333+
w[:, be:].addmm_(err, h_inv[bs:be, be:], alpha=-1)
334334
return w
335335

336336

0 commit comments

Comments (0)