Skip to content

Commit 18c9659

Browse files
committed
# Fix for matmul_4bit out Parameter Issue
1 parent d475533 commit 18c9659

2 files changed

Lines changed: 40 additions & 15 deletions

File tree

bitsandbytes/autograd/_functions.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -299,9 +299,6 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
299299

300300

301301
class MatMul4Bit(torch.autograd.Function):
302-
# forward is the same, but we added the fallback for pre-turing GPUs
303-
# backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
304-
305302
@staticmethod
306303
def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState] = None):
307304
# default of pytorch behavior if inputs are empty
@@ -319,7 +316,15 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
319316

320317
# 1. Dequantize
321318
# 2. Matmul
322-
output = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
319+
# Use linear function which correctly handles 1D and 2D inputs
320+
result = torch.nn.functional.linear(A, F.dequantize_4bit(B, quant_state).to(A.dtype).t(), bias)
321+
322+
# If out is provided, resize it if necessary and copy the result
323+
if out is not None:
324+
if out.shape != result.shape:
325+
out.resize_(result.shape)
326+
out.copy_(result)
327+
result = out
323328

324329
# 3. Save state
325330
ctx.state = quant_state
@@ -330,7 +335,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
330335
else:
331336
ctx.tensors = (None, None)
332337

333-
return output
338+
return result
334339

335340
@staticmethod
336341
def backward(ctx, grad_output):
@@ -385,9 +390,14 @@ def matmul_4bit(
385390
)
386391
return MatMul4Bit.apply(A, B, out, bias, quant_state)
387392
else:
388-
out = F.gemv_4bit(A, B.t(), out, state=quant_state)
393+
# For the 1D case, use the MatMul4Bit implementation, which correctly handles the out parameter
394+
if out is not None and A.dim() == 1:
395+
return MatMul4Bit.apply(A, B, out, bias, quant_state)
396+
397+
# For other cases, use gemv_4bit
398+
result = F.gemv_4bit(A, B.t(), out, state=quant_state)
389399
if bias is not None:
390-
out += bias
391-
return out
400+
result += bias
401+
return result
392402
else:
393403
return MatMul4Bit.apply(A, B, out, bias, quant_state)

bitsandbytes/backends/cuda/ops.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -427,11 +427,18 @@ def _(
427427
blocksize: int,
428428
out: torch.Tensor,
429429
) -> None:
430+
expected_shape = (*A.shape[:-1], shapeB[0])
431+
432+
if len(A.shape) == 1 and len(out.shape) == 2 and out.shape[0] == 1:
433+
out = out.view(shapeB[0])
434+
expected_shape = (shapeB[0],)
435+
430436
torch._check(
431-
out.shape == (*A.shape[:-1], shapeB[0]),
432-
lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
437+
out.shape == expected_shape,
438+
lambda: f"Expected out.shape == {expected_shape}, got {out.shape}",
433439
)
434440
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
441+
435442
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
436443

437444

@@ -445,10 +452,13 @@ def _gemv_4bit_impl(
445452
out: torch.Tensor,
446453
) -> None:
447454
torch._check_is_size(blocksize)
448-
torch._check(
449-
A.numel() == A.size(-1),
450-
lambda: f"A must be a vector with leading dimensions of 1, got {A.shape}",
451-
)
455+
456+
is_1d = A.dim() == 1
457+
if is_1d:
458+
A_reshaped = A.view(1, -1)
459+
else:
460+
A_reshaped = A
461+
452462
torch._check(
453463
A.dtype in [torch.float16, torch.bfloat16, torch.float32],
454464
lambda: f"A must be float16, bfloat16, or float32, got {A.dtype}",
@@ -465,11 +475,16 @@ def _gemv_4bit_impl(
465475
k = ct.c_int32(shapeB[1])
466476

467477
lda = m
468-
ldb = ct.c_int32((A.shape[-1] + 1) // 2)
478+
ldb = ct.c_int32((A_reshaped.shape[-1] + 1) // 2)
469479
ldc = m
470480

471481
stream = _get_tensor_stream(A)
472482

483+
if is_1d and out.dim() > 1:
484+
out_view = out.view(-1)
485+
else:
486+
out_view = out
487+
473488
with _cuda_device_of(A):
474489
if A.dtype == torch.float16:
475490
lib.cgemm_4bit_inference_naive_fp16(

0 commit comments

Comments
 (0)