Skip to content

Commit f21a065

Browse files
authored
Fix AWQ decode kernel shared memory launch (#2801)
1 parent bb87910 commit f21a065

2 files changed

Lines changed: 10 additions & 97 deletions

File tree

gptqmodel/nn_modules/qlinear/gemv_fast_awq.py

Lines changed: 0 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -153,11 +153,6 @@ def __init__(
153153
else:
154154
self.bias = None
155155

156-
# Blackwell/SM120 currently misbehaves with the fused AWQ kernels, so
157-
# those devices rebuild one dense compatibility weight per module.
158-
self._sm120_compat_weight: torch.Tensor | None = None
159-
self._sm120_compat_weight_device: tuple[torch.device, torch.dtype] | None = None
160-
161156
def forward(self, x: torch.Tensor):
162157
if not awq_runtime_available():
163158
raise ModuleNotFoundError("AWQ torch.ops kernels are not properly installed. Error: " + awq_runtime_error())
@@ -180,11 +175,6 @@ def forward(self, x: torch.Tensor):
180175

181176
self._ensure_runtime_buffers(device=inputs.device, dtype=inputs.dtype)
182177

183-
# Route SM120 devices through a compatibility implementation until the
184-
# fused decode/prefill kernels are fixed for Blackwell.
185-
if self._use_sm120_compat_path(inputs.device):
186-
return self._sm120_compat_forward(x=x, inputs=inputs, input_dtype=input_dtype)
187-
188178
zeros = self._runtime_zeros()
189179
if inputs_dim == 3 and batch_size < 8 and n_tokens == 1:
190180
out = awq_fast_gemv_forward_decode(
@@ -212,84 +202,9 @@ def forward(self, x: torch.Tensor):
212202

213203
return out
214204

215-
def _use_sm120_compat_path(self, device: torch.device) -> bool:
216-
"""Enable the SM120 compatibility path on Blackwell-class CUDA devices."""
217-
218-
if device.type != "cuda":
219-
return False
220-
major, _minor = torch.cuda.get_device_capability(device)
221-
return major >= 12
222-
223-
def _sm120_compat_forward(
224-
self,
225-
*,
226-
x: torch.Tensor,
227-
inputs: torch.Tensor,
228-
input_dtype: torch.dtype,
229-
) -> torch.Tensor:
230-
"""Run a dense compatibility matmul for SM120 until fused kernels are stable."""
231-
232-
out_shape = inputs.shape[:-1] + (self.out_features,)
233-
weight = self._sm120_compat_dense_weight(device=inputs.device, dtype=inputs.dtype)
234-
out = inputs.reshape(-1, inputs.shape[-1]).matmul(weight).reshape(out_shape)
235-
236-
if input_dtype != torch.float16:
237-
out = out.to(dtype=input_dtype)
238-
239-
out = out + self.bias if self.bias is not None else out
240-
241-
if self.adapter:
242-
out = self.adapter.apply(x=x, out=out)
243-
244-
return out
245-
246-
def _sm120_compat_dense_weight(self, *, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
247-
"""Cache one dense AWQ weight matrix per device/dtype for the SM120 path."""
248-
249-
cache_key = (device, dtype)
250-
if self._sm120_compat_weight is not None and self._sm120_compat_weight_device == cache_key:
251-
return self._sm120_compat_weight
252-
253-
intweight = self._unpack_reference_intweight(device=device)
254-
255-
num_groups = max(1, (self.in_features + self.group_size - 1) // self.group_size)
256-
scales = self.scales.transpose(0, 1)[:, :num_groups].to(device=device, dtype=dtype)
257-
zeros = self._runtime_zeros().transpose(0, 1)[:, :num_groups].to(device=device, dtype=dtype)
258-
259-
scales = scales.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
260-
zeros = zeros.repeat_interleave(self.group_size, dim=1)[:, : self.in_features]
261-
262-
weight = (intweight.to(dtype=dtype) * scales + zeros).transpose(0, 1).contiguous()
263-
self._sm120_compat_weight = weight
264-
self._sm120_compat_weight_device = cache_key
265-
return weight
266-
267-
def _unpack_reference_intweight(self, *, device: torch.device) -> torch.Tensor:
268-
"""Invert the GEMV_FAST int16 packing so SM120 can rebuild dense weights."""
269-
270-
packed = self.qweight.to(device=device, dtype=torch.int32)
271-
unpacked = torch.stack(
272-
[
273-
torch.bitwise_and(torch.bitwise_right_shift(packed, shift), 0xF)
274-
for shift in (0, 4, 8, 12)
275-
],
276-
dim=-1,
277-
)
278-
unpacked = unpacked.view(packed.shape[0], packed.shape[1] // 64, 4, 64)
279-
unpacked = unpacked.permute(0, 2, 1, 3).contiguous()
280-
unpacked = unpacked.view(packed.shape[0] * 4, self.in_features)
281-
unpacked = unpacked.view(packed.shape[0] * 4, self.in_features // 32, 4, 2, 4)
282-
unpacked = unpacked.permute(0, 1, 2, 4, 3).contiguous()
283-
unpacked = unpacked.view(packed.shape[0] * 4, self.in_features // 32, 32)
284-
unpacked = unpacked.view(packed.shape[0] * 4, self.in_features // 32, 4, 4, 2)
285-
unpacked = unpacked.permute(0, 1, 3, 2, 4).contiguous()
286-
return unpacked.view(self.out_features, self.in_features)
287-
288205
def _ensure_runtime_buffers(self, *, device: torch.device, dtype: torch.dtype):
289206
if self.qweight.device != device or not self.qweight.is_contiguous():
290207
self.qweight = self.qweight.to(device=device).contiguous()
291-
self._sm120_compat_weight = None
292-
self._sm120_compat_weight_device = None
293208

294209
zeros = self._runtime_zeros()
295210
if zeros.device != device or zeros.dtype != dtype or not zeros.is_contiguous():
@@ -300,13 +215,9 @@ def _ensure_runtime_buffers(self, *, device: torch.device, dtype: torch.dtype):
300215
self.scaled_zeros = zeros
301216
else:
302217
raise ValueError(f"Unsupported zeros buffer: {self.zeros_name}")
303-
self._sm120_compat_weight = None
304-
self._sm120_compat_weight_device = None
305218

306219
if self.scales.device != device or self.scales.dtype != dtype or not self.scales.is_contiguous():
307220
self.scales = self.scales.to(device=device, dtype=dtype).contiguous()
308-
self._sm120_compat_weight = None
309-
self._sm120_compat_weight_device = None
310221

311222
if self.bias is not None and (
312223
self.bias.device != device or self.bias.dtype != dtype or not self.bias.is_contiguous()

gptqmodel_ext/awq/quantization_new/gemv/gemv_cuda.cu

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,12 @@ torch::Tensor gemv_forward_cuda_decode(
264264
static constexpr int N_PER_BLOCK = 2;
265265
static constexpr int K_INTERLEAVE = 4;
266266
static constexpr int BLOCK_SIZE = 256;
267+
static constexpr int NUM_WARPS = BLOCK_SIZE / WARP_SIZE;
267268

268269
dim3 num_blocks(n / N_PER_BLOCK / K_INTERLEAVE);
269270
dim3 num_threads(BLOCK_SIZE);
271+
// warp_reduce() writes one float accumulator tile per warp into extern shared memory
272+
size_t smem_size = sizeof(float) * NUM_WARPS * N_PER_BLOCK * m * K_INTERLEAVE;
270273

271274
// if (group_size == 64)
272275
// {
@@ -282,37 +285,37 @@ torch::Tensor gemv_forward_cuda_decode(
282285
switch (m)
283286
{
284287
case 1:
285-
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
288+
gemv_kernel<N_PER_BLOCK, 1, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
286289
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
287290
);
288291
break;
289292
case 2:
290-
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
293+
gemv_kernel<N_PER_BLOCK, 2, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
291294
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
292295
);
293296
break;
294297
case 3:
295-
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
298+
gemv_kernel<N_PER_BLOCK, 3, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
296299
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
297300
);
298301
break;
299302
case 4:
300-
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
303+
gemv_kernel<N_PER_BLOCK, 4, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
301304
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
302305
);
303306
break;
304307
case 5:
305-
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
308+
gemv_kernel<N_PER_BLOCK, 5, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
306309
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
307310
);
308311
break;
309312
case 6:
310-
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
313+
gemv_kernel<N_PER_BLOCK, 6, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
311314
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
312315
);
313316
break;
314317
case 7:
315-
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads>>>(
318+
gemv_kernel<N_PER_BLOCK, 7, BLOCK_SIZE, 128><<<num_blocks, num_threads, smem_size>>>(
316319
in_feats, kernel, scaling_factors, zeros, out_feats, k, n
317320
);
318321
break;
@@ -326,4 +329,3 @@ torch::Tensor gemv_forward_cuda_decode(
326329
}
327330
return _out_feats;
328331
}
329-

0 commit comments

Comments
 (0)