Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/diffusers/utils/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,10 @@ def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.T
# Non-power of 2 images must be float32
if (W & (W - 1)) != 0 or (H & (H - 1)) != 0:
x = x.to(dtype=torch.float32)
# fftn does not support bfloat16
elif x.dtype == torch.bfloat16:
# fftn does not support bfloat16, and produces the experimental ComplexHalf
# dtype (torch.complex32) when given float16, which is numerically unstable
# and triggers a UserWarning.
elif x.dtype in (torch.bfloat16, torch.float16):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
elif x.dtype in (torch.bfloat16, torch.float16):
elif x.dtype != torch.float32:

x = x.to(dtype=torch.float32)

# FFT
Expand Down
43 changes: 43 additions & 0 deletions tests/others/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,49 @@ def test_deprecate_testing_utils_module(self):
), f"Expected deprecation message substring not found, got: {messages}"


class FourierFilterTester(unittest.TestCase):
"""Tests for :func:`diffusers.utils.torch_utils.fourier_filter` (FreeU helper)."""

def _run_without_complexhalf_warning(self, dtype) -> "torch.Tensor":
import torch

from diffusers.utils.torch_utils import fourier_filter

x = torch.randn(1, 4, 32, 32, dtype=dtype)
with warnings.catch_warnings(record=True) as caught:
warnings.simplefilter("always")
out = fourier_filter(x, threshold=1, scale=0.5)

messages = [str(w.message) for w in caught]
assert not any("ComplexHalf" in m for m in messages), (
f"Unexpected ComplexHalf warning emitted by fourier_filter: {messages}"
)
return out

def test_fourier_filter_float16_no_complexhalf_warning(self):
import torch

out = self._run_without_complexhalf_warning(torch.float16)
assert out.dtype == torch.float16

def test_fourier_filter_bfloat16_no_complexhalf_warning(self):
import torch

out = self._run_without_complexhalf_warning(torch.bfloat16)
assert out.dtype == torch.bfloat16

def test_fourier_filter_preserves_dtype_and_shape(self):
import torch

from diffusers.utils.torch_utils import fourier_filter

for dtype in (torch.float32, torch.float16, torch.bfloat16):
x = torch.randn(2, 3, 16, 16, dtype=dtype)
out = fourier_filter(x, threshold=1, scale=0.5)
assert out.dtype == dtype
assert out.shape == x.shape


# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):
def test_expectations(self):
Expand Down
Loading