From 98772121bd1f0acb9a0ca9652932d076cfe0328b Mon Sep 17 00:00:00 2001 From: Taksh Date: Mon, 20 Apr 2026 14:29:29 +0530 Subject: [PATCH] Use _cast_input_dtype in poly Linear.forward The raw x.to(A.dtype) cast bypasses the disable_lora_input_dtype_casting context manager. Switch to the BaseTunerLayer._cast_input_dtype helper so poly respects the same casting controls as other tuners (fourierft, vera, etc.). --- src/peft/tuners/poly/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/peft/tuners/poly/layer.py b/src/peft/tuners/poly/layer.py index 13ee0d9e06..f946017186 100644 --- a/src/peft/tuners/poly/layer.py +++ b/src/peft/tuners/poly/layer.py @@ -154,7 +154,7 @@ def forward(self, x: torch.Tensor, *args: Any, task_ids: torch.Tensor = None, ** A = A.reshape(bs, self.in_features, r) B = B.transpose(1, 2).reshape(bs, r, self.out_features) - x = x.to(A.dtype) + x = self._cast_input_dtype(x, A.dtype) result += x.bmm(A).bmm(B) / r result = result.to(previous_dtype)