diff --git a/src/peft/tuners/poly/layer.py b/src/peft/tuners/poly/layer.py index 13ee0d9e06..f946017186 100644 --- a/src/peft/tuners/poly/layer.py +++ b/src/peft/tuners/poly/layer.py @@ -154,7 +154,7 @@ def forward(self, x: torch.Tensor, *args: Any, task_ids: torch.Tensor = None, ** A = A.reshape(bs, self.in_features, r) B = B.transpose(1, 2).reshape(bs, r, self.out_features) - x = x.to(A.dtype) + x = self._cast_input_dtype(x, A.dtype) result += x.bmm(A).bmm(B) / r result = result.to(previous_dtype)