Skip to content

Use _cast_input_dtype in poly Linear.forward#3177

Open
Chessing234 wants to merge 1 commit intohuggingface:mainfrom
Chessing234:fix/poly-use-cast-input-dtype
Open

Use _cast_input_dtype in poly Linear.forward#3177
Chessing234 wants to merge 1 commit intohuggingface:mainfrom
Chessing234:fix/poly-use-cast-input-dtype

Conversation

@Chessing234
Copy link
Copy Markdown
Contributor

Bug

poly.Linear.forward casts the input with a raw x.to(A.dtype) instead of
self._cast_input_dtype(...), bypassing the disable_lora_input_dtype_casting
context manager.

Root cause

BaseTunerLayer exposes _cast_input_dtype, which respects the module-level
flag for disabling input dtype casting. Most tuners use that helper, but
poly/layer.py still uses the raw .to() call.

Why the fix is correct

Change

src/peft/tuners/poly/layer.py: x = x.to(A.dtype)
x = self._cast_input_dtype(x, A.dtype).

The raw x.to(A.dtype) cast bypasses the disable_lora_input_dtype_casting
context manager. Switch to the BaseTunerLayer._cast_input_dtype helper so
poly respects the same casting controls as other tuners (fourierft, vera,
etc.).
@BenjaminBossan
Copy link
Copy Markdown
Member

@Chessing234 Thanks for the PR. Could you please check all PEFT methods for the same pattern and bundle all changes into a single PR instead of submitting one for each PEFT method individually? Thanks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants