Skip to content

Commit 0e8e535

Browse files
committed
fix peft cpu ut failure
Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
1 parent 09ea861 commit 0e8e535

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -679,6 +679,7 @@ def to(self: T, tensor: Tensor, non_blocking: bool = ...) -> T: ...
679679
def to(self, *args, **kwargs):
680680
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(*args, **kwargs)
681681

682+
682683
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
683684
if device.type != "cpu" or self.data.dtype != torch.int8:
684685
return self._quantize(device)
@@ -690,8 +691,8 @@ def to(self, *args, **kwargs):
690691
requires_grad=self.requires_grad,
691692
has_fp16_weights=self.has_fp16_weights,
692693
)
693-
new_param.CB = self.CB
694-
new_param.SCB = self.SCB
694+
new_param.CB = self.CB.to(device=device) if self.CB != None else self.CB
695+
new_param.SCB = self.SCB.to(device=device) if self.SCB != None else self.SCB
695696

696697
return new_param
697698

0 commit comments

Comments
 (0)