
fix(freeu): run FFT in float32 for float16 inputs to avoid ComplexHalf #13503

Open
Ricardo-M-L wants to merge 2 commits into huggingface:main from Ricardo-M-L:fix/fourier-filter-fp16

Conversation

@Ricardo-M-L

What does this PR do?

Closes #12504.

fourier_filter (the FFT helper used by enable_freeu) already upcasts
bfloat16 inputs to float32 before calling torch.fft.fftn, because
PyTorch's FFT does not support bf16. The same is true for float16:
depending on the PyTorch version, fftn either

  • produces the experimental torch.complex32 (ComplexHalf) dtype and
    emits a UserWarning: ComplexHalf support is experimental… (the
    original symptom in #12504, "freeU ComplexHalf warning"), or
  • raises RuntimeError: Unsupported dtype Half outright (reproduced on
    CPU with torch==2.9.0).

Both paths were reachable from FreeU with half-precision models
(sd-turbo + torch_dtype=torch.float16 + enable_freeu(…)).

Fix

Extend the existing upcast branch to cover float16 as well as
bfloat16. The function already downcasts back to x_in.dtype at the
end, so the externally observable dtype is unchanged.

    # Non-power of 2 images must be float32
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        x = x.to(dtype=torch.float32)
    # fftn does not support bfloat16, and produces the experimental ComplexHalf
    # dtype (torch.complex32) when given float16, which is numerically unstable
    # and triggers a UserWarning.
    elif x.dtype in (torch.bfloat16, torch.float16):
        x = x.to(dtype=torch.float32)

Running FFT in fp32 and downcasting is also more numerically accurate
than staying in fp16/complex32.
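The upcast-then-downcast pattern can be sketched end to end. The following is a hypothetical NumPy stand-in for the torch-based `fourier_filter` (NumPy is used only so the sketch runs without a torch install; `fourier_lowpass`, `scale`, and `threshold` are illustrative names, not the diffusers API):

```python
import numpy as np

def fourier_lowpass(x, scale=0.5, threshold=1):
    """Sketch of the upcast-then-downcast pattern: FFT math runs in a
    wide dtype, and the result is cast back to the input dtype so
    callers never observe the intermediate precision."""
    orig_dtype = x.dtype
    # Upcast half-precision input before the FFT (mirrors the fp16/bf16 branch).
    if x.dtype != np.float32:
        x = x.astype(np.float32)

    # FFT over the spatial dims, with the zero frequency shifted to the center.
    x_freq = np.fft.fftn(x, axes=(-2, -1))
    x_freq = np.fft.fftshift(x_freq, axes=(-2, -1))

    # Scale a small box of low frequencies around the center.
    B, C, H, W = x_freq.shape
    mask = np.ones((B, C, H, W), dtype=np.float32)
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold : crow + threshold,
              ccol - threshold : ccol + threshold] = scale
    x_freq = x_freq * mask

    # Inverse transform, then downcast so the externally observable
    # dtype is unchanged.
    x_freq = np.fft.ifftshift(x_freq, axes=(-2, -1))
    x_filtered = np.fft.ifftn(x_freq, axes=(-2, -1)).real
    return x_filtered.astype(orig_dtype)

x = np.random.rand(1, 1, 8, 8).astype(np.float16)
out = fourier_lowpass(x)
assert out.dtype == np.float16 and out.shape == x.shape
```

The key point is that `orig_dtype` round-trips: the caller hands in fp16 and gets fp16 back, while every FFT intermediate lives in at least float32.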

Tests

Added tests/others/test_utils.py::FourierFilterTester with three
cases:

  • test_fourier_filter_float16_no_complexhalf_warning — fp16 input
    produces no ComplexHalf warning and returns an fp16 tensor.
  • test_fourier_filter_bfloat16_no_complexhalf_warning — regression
    guard for the existing bf16 path.
  • test_fourier_filter_preserves_dtype_and_shape — verifies dtype and
    shape round-trip for all three supported dtypes.

Locally on torch==2.9.0+cpu, these tests fail without the patch
(RuntimeError for fp16) and pass with it.

tests/others/test_utils.py::FourierFilterTester::test_fourier_filter_bfloat16_no_complexhalf_warning PASSED
tests/others/test_utils.py::FourierFilterTester::test_fourier_filter_float16_no_complexhalf_warning PASSED
tests/others/test_utils.py::FourierFilterTester::test_fourier_filter_preserves_dtype_and_shape PASSED
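A minimal, framework-free version of the warning guard these tests rely on might look like the following; `assert_no_complexhalf_warning` is a hypothetical helper for illustration, not the actual code in tests/others/test_utils.py:

```python
import warnings

def assert_no_complexhalf_warning(fn):
    # Run fn and fail if any ComplexHalf UserWarning was emitted,
    # mirroring how the regression test guards the fp16 path.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        result = fn()
    assert not any("ComplexHalf" in str(w.message) for w in caught), \
        "fp16 path emitted a ComplexHalf warning"
    return result

# A callable that emits no warning passes the guard unchanged.
out = assert_no_complexhalf_warning(lambda: 42)
assert out == 42
```

In the real tests the callable would be `fourier_filter` on an fp16 tensor, and the returned value is additionally checked for dtype and shape.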

Why not PR #12511?

PR #12511 addresses the same warning by locally suppressing it with
warnings.catch_warnings() around fftn/ifftn. That hides the
warning but keeps the FFT running on ComplexHalf (or raising on
platforms where fp16 FFT is unsupported). This PR instead removes the
root cause: the FFT is run in fp32, matching the existing bf16 handling.

Who can review?

@sayakpaul @yiyixuxu

AI assistance (Claude) was used to draft this patch. I reviewed every
line and ran the tests locally.

@github-actions github-actions bot added the tests, utils, and size/S (PR with diff < 50 LOC) labels on Apr 19, 2026
Comment thread on src/diffusers/utils/torch_utils.py (Outdated)
# fftn does not support bfloat16, and produces the experimental ComplexHalf
# dtype (torch.complex32) when given float16, which is numerically unstable
# and triggers a UserWarning.
elif x.dtype in (torch.bfloat16, torch.float16):

Suggested change
elif x.dtype in (torch.bfloat16, torch.float16):
elif x.dtype != torch.float32:

- Apply @sayakpaul's suggestion: use `elif x.dtype != torch.float32:`
  so any non-float32 dtype (bf16, fp16, and future half-precision
  dtypes) is upcast to float32 before the FFT.
- Drop the `"torch.Tensor"` return annotation on the test helper
  that triggered ruff F821 in CI (torch is imported inside the
  method body, not at module scope).
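The suggested generalization can be sketched as a small dtype-dispatch helper (a NumPy stand-in with the hypothetical name `to_fft_dtype`, mirroring the existing power-of-2 check and the suggested `x.dtype != float32` branch):

```python
import numpy as np

def to_fft_dtype(x):
    # Generalized upcast from the review suggestion: anything that is
    # not already float32 (fp16, bf16, future half dtypes) goes to
    # float32 before the FFT, instead of enumerating dtypes one by one.
    H, W = x.shape[-2], x.shape[-1]
    if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
        return x.astype(np.float32)   # non-power-of-2 sizes must be fp32
    elif x.dtype != np.float32:
        return x.astype(np.float32)   # covers fp16 and any other dtype
    return x

assert to_fft_dtype(np.zeros((4, 8), dtype=np.float16)).dtype == np.float32
assert to_fft_dtype(np.zeros((4, 8), dtype=np.float32)).dtype == np.float32
```

One consequence of the broader check is that any non-float32 input, not just the half-precision dtypes, would be run through the FFT in float32; since the function restores the input dtype at the end, that stays invisible to callers.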
@github-actions github-actions bot added and removed the size/S (PR with diff < 50 LOC) label on Apr 19, 2026
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

