fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf#13503
Open
Ricardo-M-L wants to merge 2 commits into huggingface:main from
Conversation
`fourier_filter` already upcasts `bfloat16` inputs to `float32` before calling `torch.fft.fftn`, because PyTorch's FFT does not support bf16. The same is true for `float16`: depending on the PyTorch version, `fftn` either

- produces the experimental `torch.complex32` (ComplexHalf) dtype and emits a `UserWarning: ComplexHalf support is experimental`, or
- raises `RuntimeError: Unsupported dtype Half` outright.

Both paths were reachable from FreeU with half-precision models (e.g. `sd-turbo` + `fp16` + `enable_freeu`) as reported in huggingface#12504.

Extend the existing upcast branch to cover `float16` too. The function already downcasts the result back to `x_in.dtype` at the end, so the externally observable dtype is unchanged.

Closes huggingface#12504.
sayakpaul
approved these changes
Apr 19, 2026
# fftn does not support bfloat16, and produces the experimental ComplexHalf
# dtype (torch.complex32) when given float16, which is numerically unstable
# and triggers a UserWarning.
elif x.dtype in (torch.bfloat16, torch.float16):
Member
Suggested change
- elif x.dtype in (torch.bfloat16, torch.float16):
+ elif x.dtype != torch.float32:
- Apply @sayakpaul's suggestion: use `elif x.dtype != torch.float32:` so any non-float32 dtype (bf16, fp16, and future half-precision dtypes) is upcast to float32 before the FFT.
- Drop the `"torch.Tensor"` return annotation on the test helper that triggered ruff F821 in CI (torch is imported inside the method body, not at module scope).
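The reviewed check can be sketched as a tiny helper. This is illustrative only: `to_fft_dtype` is a hypothetical name, not a diffusers function; the point is that a single "not already float32" test subsumes the explicit bf16/fp16 allow-list.

```python
import torch


def to_fft_dtype(x: torch.Tensor) -> torch.Tensor:
    # torch.fft.fftn handles float32 natively; bf16 is unsupported and fp16
    # produces the experimental ComplexHalf dtype, so upcast anything else.
    return x.to(torch.float32) if x.dtype != torch.float32 else x
```

With this shape of check, any future reduced-precision float dtype is routed through the same fp32 path without touching the condition again.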
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
Closes #12504.
`fourier_filter` (the FFT helper used by `enable_freeu`) already upcasts `bfloat16` inputs to `float32` before calling `torch.fft.fftn`, because PyTorch's FFT does not support bf16. The same is true for `float16`: depending on the PyTorch version, `fftn` either

- produces the experimental `torch.complex32` (ComplexHalf) dtype and emits a `UserWarning: ComplexHalf support is experimental…` (the original symptom in freeU ComplexHalf warning #12504), or
- raises `RuntimeError: Unsupported dtype Half` outright (reproduced on CPU with `torch==2.9.0`).

Both paths were reachable from FreeU with half-precision models (`sd-turbo` + `torch_dtype=torch.float16` + `enable_freeu(…)`).

Fix

Extend the existing upcast branch to cover `float16` as well as `bfloat16`. The function already downcasts back to `x_in.dtype` at the end, so the externally observable dtype is unchanged. Running the FFT in fp32 and downcasting is also more numerically accurate than staying in fp16/complex32.
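The upcast/filter/downcast flow can be sketched as a self-contained function. This is a hedged sketch of the pattern, not the diffusers implementation: the `threshold`/`scale` defaults and the square low-frequency mask are illustrative, and `fourier_filter_sketch` is a made-up name.

```python
import torch


def fourier_filter_sketch(x_in: torch.Tensor, threshold: int = 1, scale: float = 0.9) -> torch.Tensor:
    """Sketch of a FreeU-style Fourier filter with the dtype fix applied."""
    x = x_in
    # fftn supports neither bfloat16 nor float16 (ComplexHalf is experimental),
    # so upcast any non-float32 input before the FFT.
    if x.dtype != torch.float32:
        x = x.to(torch.float32)

    B, C, H, W = x.shape
    x_freq = torch.fft.fftn(x, dim=(-2, -1))
    x_freq = torch.fft.fftshift(x_freq, dim=(-2, -1))

    # Attenuate the low-frequency center of the shifted spectrum.
    mask = torch.ones((B, C, H, W), device=x.device)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold, ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    x_freq = torch.fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = torch.fft.ifftn(x_freq, dim=(-2, -1)).real

    # Downcast so the externally observable dtype matches the input.
    return x_filtered.to(x_in.dtype)
```

With `scale=1.0` the mask is a no-op, so the function reduces to an fft/ifft round-trip and the output matches the input up to fp32 rounding, while the returned dtype is always `x_in.dtype`.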
Tests
Added `tests/others/test_utils.py::FourierFilterTester` with three cases:

- `test_fourier_filter_float16_no_complexhalf_warning` — fp16 input produces no ComplexHalf warning and returns an fp16 tensor.
- `test_fourier_filter_bfloat16_no_complexhalf_warning` — regression guard for the existing bf16 path.
- `test_fourier_filter_preserves_dtype_and_shape` — verifies dtype and shape round-trip for all three supported dtypes.

Locally on `torch==2.9.0+cpu`, these tests fail without the patch (`RuntimeError` for fp16) and pass with it.
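The shape of the fp16 regression check can be sketched standalone, without importing diffusers. This is not the actual test from the PR: the helper name is hypothetical, and the fft/ifft round-trip here stands in for the patched `fourier_filter` to show the warning-capture pattern.

```python
import warnings

import torch


def check_fp16_fft_no_complexhalf_warning() -> torch.Tensor:
    """Upcast -> fft -> ifft -> downcast must emit no ComplexHalf warning
    and must hand back an fp16 tensor."""
    x = torch.randn(1, 2, 8, 8, dtype=torch.float16)
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # The fix under test: run the FFT in fp32, never on a Half tensor.
        x_freq = torch.fft.fftn(x.to(torch.float32), dim=(-2, -1))
        out = torch.fft.ifftn(x_freq, dim=(-2, -1)).real.to(x.dtype)
    assert not any("ComplexHalf" in str(w.message) for w in caught)
    assert out.dtype == torch.float16
    return out
```

Without the upcast, the same sequence either raises `RuntimeError` (CPU) or records the ComplexHalf `UserWarning`, which is what makes this a usable regression guard.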
Why not PR #12511?
PR #12511 addresses the same warning by locally suppressing it with `warnings.catch_warnings()` around `fftn`/`ifftn`. That hides the warning but keeps the FFT running on ComplexHalf (or raising on platforms where fp16 FFT is unsupported). This PR instead removes the root cause: the FFT is run in fp32, matching the existing bf16 handling.
Before submitting
Who can review?
@sayakpaul @yiyixuxu
AI assistance (Claude) was used to draft this patch. I reviewed every
line and ran the tests locally.