@@ -299,9 +299,6 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor
299299
300300
301301class MatMul4Bit (torch .autograd .Function ):
302- # forward is the same, but we added the fallback for pre-turing GPUs
303- # backward is mostly the same, but adds one extra clause (see "elif state.CxB is not None")
304-
305302 @staticmethod
306303 def forward (ctx , A , B , out = None , bias = None , quant_state : Optional [F .QuantState ] = None ):
307304 # default of pytorch behavior if inputs are empty
@@ -319,7 +316,15 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
319316
320317 # 1. Dequantize
321318 # 2. MatmulnN
322- output = torch .nn .functional .linear (A , F .dequantize_4bit (B , quant_state ).to (A .dtype ).t (), bias )
319+ # Use linear function which correctly handles 1D and 2D inputs
320+ result = torch .nn .functional .linear (A , F .dequantize_4bit (B , quant_state ).to (A .dtype ).t (), bias )
321+
322+ # If out is provided, resize it if necessary and copy the result
323+ if out is not None :
324+ if out .shape != result .shape :
325+ out .resize_ (result .shape )
326+ out .copy_ (result )
327+ result = out
323328
324329 # 3. Save state
325330 ctx .state = quant_state
@@ -330,7 +335,7 @@ def forward(ctx, A, B, out=None, bias=None, quant_state: Optional[F.QuantState]
330335 else :
331336 ctx .tensors = (None , None )
332337
333- return output
338+ return result
334339
335340 @staticmethod
336341 def backward (ctx , grad_output ):
@@ -385,9 +390,14 @@ def matmul_4bit(
385390 )
386391 return MatMul4Bit .apply (A , B , out , bias , quant_state )
387392 else :
388- out = F .gemv_4bit (A , B .t (), out , state = quant_state )
393+ # For 1D case, we'll use the MatMul4Bit implementation which correctly handles out parameter
394+ if out is not None and A .dim () == 1 :
395+ return MatMul4Bit .apply (A , B , out , bias , quant_state )
396+
397+ # For other cases, use gemv_4bit
398+ result = F .gemv_4bit (A , B .t (), out , state = quant_state )
389399 if bias is not None :
390- out += bias
391- return out
400+ result += bias
401+ return result
392402 else :
393403 return MatMul4Bit .apply (A , B , out , bias , quant_state )
0 commit comments