
Commit d74877c

Fix model utility edge cases
1 parent 48f39c2 commit d74877c

15 files changed, 311 additions & 64 deletions

src/diffusers/models/attention_processor.py

Lines changed: 4 additions & 4 deletions
@@ -324,12 +324,12 @@ def set_use_xla_flash_attention(
                 Specify the partition specification if using SPMD. Otherwise None.
         """
         if use_xla_flash_attention:
-            if not is_torch_xla_available:
-                raise "torch_xla is not available"
+            if not is_torch_xla_available():
+                raise ImportError("torch_xla is not available")
             elif is_torch_xla_version("<", "2.3"):
-                raise "flash attention pallas kernel is supported from torch_xla version 2.3"
+                raise ImportError("flash attention pallas kernel is supported from torch_xla version 2.3")
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
-                raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
+                raise ImportError("flash attention pallas kernel using SPMD is supported from torch_xla version 2.4")
             else:
                 if is_flux:
                     processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
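
A minimal standalone sketch of the two bugs this hunk fixes (nothing below is the diffusers API; the helper is a stand-in): referencing `is_torch_xla_available` without calling it is always truthy, so the guard was dead code, and `raise` with a bare string produces a `TypeError` instead of surfacing the intended message.

def is_torch_xla_available():
    return False  # stand-in: pretend torch_xla is not installed

print(not is_torch_xla_available)    # False -- a function object is truthy, so the old guard never fired
print(not is_torch_xla_available())  # True  -- the fixed guard calls the helper and fires as intended

try:
    raise "torch_xla is not available"  # old style: not an exception instance
except TypeError as exc:
    print(exc)  # "exceptions must derive from BaseException" -- the message was never shown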

src/diffusers/models/auto_model.py

Lines changed: 13 additions & 5 deletions
@@ -275,10 +275,17 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No
         library = None
         orig_class_name = None

+        def load_config_with_name(config_name, *args, **kwargs):
+            original_config_name = cls.config_name
+            try:
+                cls.config_name = config_name
+                return cls.load_config(*args, **kwargs)
+            finally:
+                cls.config_name = original_config_name
+
         # Always attempt to fetch model_index.json first
         try:
-            cls.config_name = "model_index.json"
-            config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+            config = load_config_with_name("model_index.json", pretrained_model_or_path, **load_config_kwargs)

             if subfolder is not None and subfolder in config:
                 library, orig_class_name = config[subfolder]
@@ -289,8 +296,9 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No

         # Unable to load from model_index.json so fallback to loading from config
         if library is None and orig_class_name is None:
-            cls.config_name = "config.json"
-            config = cls.load_config(pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs)
+            config = load_config_with_name(
+                "config.json", pretrained_model_or_path, subfolder=subfolder, **load_config_kwargs
+            )

         if "_class_name" in config:
             # If we find a class name in the config, we can try to load the model as a diffusers model
@@ -342,7 +350,7 @@ def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = No

         load_id_kwargs = {"pretrained_model_name_or_path": pretrained_model_or_path, **kwargs}
         parts = [load_id_kwargs.get(field, "null") for field in DIFFUSERS_LOAD_ID_FIELDS]
-        load_id = "|".join("null" if p is None else p for p in parts)
+        load_id = "|".join("null" if p is None else str(p) for p in parts)
         model._diffusers_load_id = load_id

         return model
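
The pattern introduced here is "temporarily swap a class attribute, restore it in `finally`". A tiny self-contained sketch (the `Loader` class and its methods are invented for illustration, not the diffusers API):

class Loader:
    config_name = "model_index.json"

    @classmethod
    def load_config(cls, path):
        return {"loaded_from": f"{path}/{cls.config_name}"}

    @classmethod
    def load_config_with_name(cls, config_name, *args, **kwargs):
        original_config_name = cls.config_name
        try:
            cls.config_name = config_name
            return cls.load_config(*args, **kwargs)
        finally:
            cls.config_name = original_config_name  # restored even if load_config raises

print(Loader.load_config_with_name("config.json", "some/repo"))  # {'loaded_from': 'some/repo/config.json'}
print(Loader.config_name)  # 'model_index.json' -- the class attribute is back to its original value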

src/diffusers/models/cache_utils.py

Lines changed: 4 additions & 3 deletions
@@ -159,6 +159,7 @@ def cache_context(self, name: str):
         registry = HookRegistry.check_if_exists_or_initialize(self)
         registry._set_context(name)

-        yield
-
-        registry._set_context(None)
+        try:
+            yield
+        finally:
+            registry._set_context(None)
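
Why the `try`/`finally` matters: in a generator-based context manager, code after the bare `yield` is skipped whenever the `with` body raises, so the context was never cleared. A standalone sketch (the `state` dict stands in for the hook registry):

from contextlib import contextmanager

state = {"context": None}

@contextmanager
def cache_context(name):
    state["context"] = name
    try:
        yield
    finally:
        state["context"] = None  # reached even when the with-body raises

try:
    with cache_context("failed-call"):
        raise RuntimeError("interrupted")
except RuntimeError:
    pass

print(state["context"])  # None -- without try/finally this would still be "failed-call"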

src/diffusers/models/downsampling.py

Lines changed: 18 additions & 23 deletions
@@ -18,7 +18,7 @@

 from ..utils import deprecate
 from .normalization import RMSNorm
-from .upsampling import upfirdn2d_native
+from .upsampling import _prepare_fir_kernel, upfirdn2d_native


 class Downsample1D(nn.Module):
@@ -210,32 +210,29 @@ def _downsample_2d(
         """

         assert isinstance(factor, int) and factor >= 1
-        if kernel is None:
-            kernel = [1] * factor
-
-        # setup kernel
-        kernel = torch.tensor(kernel, dtype=torch.float32)
-        if kernel.ndim == 1:
-            kernel = torch.outer(kernel, kernel)
-        kernel /= torch.sum(kernel)
-
-        kernel = kernel * gain
+        kernel = _prepare_fir_kernel(
+            kernel,
+            factor=factor,
+            gain=gain,
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )

         if self.use_conv:
             _, _, convH, convW = weight.shape
             pad_value = (kernel.shape[0] - factor) + (convW - 1)
             stride_value = [factor, factor]
             upfirdn_input = upfirdn2d_native(
                 hidden_states,
-                torch.tensor(kernel, device=hidden_states.device),
+                kernel,
                 pad=((pad_value + 1) // 2, pad_value // 2),
             )
             output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
         else:
             pad_value = kernel.shape[0] - factor
             output = upfirdn2d_native(
                 hidden_states,
-                torch.tensor(kernel, device=hidden_states.device),
+                kernel,
                 down=factor,
                 pad=((pad_value + 1) // 2, pad_value // 2),
             )
@@ -380,19 +377,17 @@ def downsample_2d(
     """

     assert isinstance(factor, int) and factor >= 1
-    if kernel is None:
-        kernel = [1] * factor
-
-    kernel = torch.tensor(kernel, dtype=torch.float32)
-    if kernel.ndim == 1:
-        kernel = torch.outer(kernel, kernel)
-    kernel /= torch.sum(kernel)
-
-    kernel = kernel * gain
+    kernel = _prepare_fir_kernel(
+        kernel,
+        factor=factor,
+        gain=gain,
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
     pad_value = kernel.shape[0] - factor
     output = upfirdn2d_native(
         hidden_states,
-        kernel.to(device=hidden_states.device),
+        kernel,
         down=factor,
         pad=((pad_value + 1) // 2, pad_value // 2),
     )
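
For reference, a self-contained sketch of the kernel preparation this hunk delegates to the shared helper (the function name below is local to this sketch): 1D taps become a separable 2D kernel via an outer product, are normalized to unit DC gain, and are then scaled.

import torch

def prepare_fir_kernel(kernel, factor, gain, upsample=False):
    if kernel is None:
        kernel = [1] * factor                      # default box-filter taps
    kernel = torch.as_tensor(kernel, dtype=torch.float32)
    if kernel.ndim == 1:
        kernel = torch.outer(kernel, kernel)       # separable 1D taps -> 2D kernel
    kernel = kernel / kernel.sum()                 # normalize to unit DC gain
    scale = gain * (factor**2) if upsample else gain
    return kernel * scale

print(prepare_fir_kernel(None, factor=2, gain=1))
# tensor([[0.2500, 0.2500],
#         [0.2500, 0.2500]])  -> a 2x2 averaging (box) filter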

src/diffusers/models/embeddings.py

Lines changed: 2 additions & 2 deletions
@@ -328,7 +328,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin
         output_type (`str`, *optional*, defaults to `"np"`): Output type. Use `"pt"` for PyTorch tensors.
         flip_sin_to_cos (`bool`, *optional*, defaults to `False`): Whether to flip sine and cosine embeddings.
         dtype (`torch.dtype`, *optional*): Data type for frequency calculations. If `None`, defaults to
-            `torch.float32` on MPS devices (which don't support `torch.float64`) and `torch.float64` on other devices.
+            `torch.float32`.

     Returns:
         `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
@@ -346,7 +346,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin

     # Auto-detect appropriate dtype if not specified
     if dtype is None:
-        dtype = torch.float32 if pos.device.type == "mps" else torch.float64
+        dtype = torch.float32

     omega = torch.arange(embed_dim // 2, device=pos.device, dtype=dtype)
     omega /= embed_dim / 2.0
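
A quick standalone check of the shape and dtype this function produces under the new float32 default (this follows the usual sin/cos positional-embedding recipe and is not the diffusers function itself):

import torch

embed_dim, dtype = 8, torch.float32           # float32 is now the default frequency dtype
pos = torch.tensor([0.0, 1.0, 2.0])

omega = torch.arange(embed_dim // 2, dtype=dtype)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega                    # geometric frequency ladder
out = torch.einsum("m,d->md", pos, omega)
emb = torch.cat([torch.sin(out), torch.cos(out)], dim=1)
print(emb.shape, emb.dtype)                   # torch.Size([3, 8]) torch.float32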

src/diffusers/models/modeling_flax_utils.py

Lines changed: 2 additions & 0 deletions
@@ -318,6 +318,8 @@ def from_pretrained(
                 subfolder=subfolder,
                 **kwargs,
             )
+        else:
+            unused_kwargs = kwargs

         model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)
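
A tiny repro of the edge case this hunk guards against (independent of flax; the function below is invented for the sketch): a name bound only inside the `if` branch is undefined when a config is passed in and the branch is skipped.

def from_pretrained(config=None, **kwargs):
    if config is None:
        config, unused_kwargs = {"hidden_size": 8}, {}
    else:
        unused_kwargs = kwargs  # the added fallback; without it the next line raises NameError
    return config, unused_kwargs

print(from_pretrained(config={"hidden_size": 16}, foo="bar"))
# ({'hidden_size': 16}, {'foo': 'bar'})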

src/diffusers/models/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -1592,7 +1592,7 @@ def enable_parallelism(
             "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
         )

-        if not torch.distributed.is_available() and not torch.distributed.is_initialized():
+        if not torch.distributed.is_available() or not torch.distributed.is_initialized():
             raise RuntimeError(
                 "torch.distributed must be available and initialized before calling `enable_parallelism`."
             )
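
The boolean fix in one line: the guard must raise if either condition fails. A sketch of the case the old check missed, with hypothetical flag values:

# A machine where torch.distributed exists but init_process_group() was never called:
available, initialized = True, False

old_guard = (not available) and (not initialized)  # False -> the old check let this slip through
new_guard = (not available) or (not initialized)   # True  -> the fixed check raises as documented
print(old_guard, new_guard)  # False True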

src/diffusers/models/upsampling.py

Lines changed: 40 additions & 23 deletions
@@ -21,6 +21,27 @@
 from .normalization import RMSNorm


+def _prepare_fir_kernel(
+    kernel: torch.Tensor | None,
+    *,
+    factor: int,
+    gain: float,
+    device: torch.device,
+    dtype: torch.dtype,
+    upsample: bool = False,
+) -> torch.Tensor:
+    if kernel is None:
+        kernel = [1] * factor
+
+    kernel = torch.as_tensor(kernel, device=device, dtype=torch.float32)
+    if kernel.ndim == 1:
+        kernel = torch.outer(kernel, kernel)
+    kernel = kernel / torch.sum(kernel)
+
+    scale = gain * (factor**2) if upsample else gain
+    return (kernel * scale).to(device=device, dtype=dtype)
+
+
 class Upsample1D(nn.Module):
     """A 1D upsampling layer with an optional convolution.
@@ -253,17 +274,14 @@ def _upsample_2d(

         assert isinstance(factor, int) and factor >= 1

-        # Setup filter kernel.
-        if kernel is None:
-            kernel = [1] * factor
-
-        # setup kernel
-        kernel = torch.tensor(kernel, dtype=torch.float32)
-        if kernel.ndim == 1:
-            kernel = torch.outer(kernel, kernel)
-        kernel /= torch.sum(kernel)
-
-        kernel = kernel * (gain * (factor**2))
+        kernel = _prepare_fir_kernel(
+            kernel,
+            factor=factor,
+            gain=gain,
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+            upsample=True,
+        )

         if self.use_conv:
             convH = weight.shape[2]
@@ -300,14 +318,14 @@ def _upsample_2d(

             output = upfirdn2d_native(
                 inverse_conv,
-                torch.tensor(kernel, device=inverse_conv.device),
+                kernel.to(device=inverse_conv.device, dtype=inverse_conv.dtype),
                 pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2 + 1),
             )
         else:
             pad_value = kernel.shape[0] - factor
             output = upfirdn2d_native(
                 hidden_states,
-                torch.tensor(kernel, device=hidden_states.device),
+                kernel,
                 up=factor,
                 pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
             )
@@ -496,19 +514,18 @@ def upsample_2d(
         Tensor of the shape `[N, C, H * factor, W * factor]`
     """
     assert isinstance(factor, int) and factor >= 1
-    if kernel is None:
-        kernel = [1] * factor
-
-    kernel = torch.tensor(kernel, dtype=torch.float32)
-    if kernel.ndim == 1:
-        kernel = torch.outer(kernel, kernel)
-    kernel /= torch.sum(kernel)
-
-    kernel = kernel * (gain * (factor**2))
+    kernel = _prepare_fir_kernel(
+        kernel,
+        factor=factor,
+        gain=gain,
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+        upsample=True,
+    )
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
-        kernel.to(device=hidden_states.device),
+        kernel,
        up=factor,
        pad=((pad_value + 1) // 2 + factor - 1, pad_value // 2),
    )
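
A back-of-envelope check of why the upsample path scales the kernel by gain * factor**2 (independent of the helper): zero-insertion upsampling leaves only one non-zero sample per factor-by-factor block, so the filtered signal's average level drops by factor**2 unless the kernel is scaled back up.

import torch

factor = 2
x = torch.ones(1, 1, 4, 4)
up = torch.zeros(1, 1, 4 * factor, 4 * factor)
up[..., ::factor, ::factor] = x                # zero-insertion upsampling
print(x.mean().item(), up.mean().item())       # 1.0 vs 0.25 -> energy down by factor**2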

src/diffusers/utils/import_utils.py

Lines changed: 1 addition & 1 deletion
@@ -708,7 +708,7 @@ def is_torch_xla_version(operation: str, version: str):
        version (`str`):
            A string version of torch_xla
    """
-    if not is_torch_xla_available:
+    if not is_torch_xla_available():
        return False
    return compare_versions(parse(_torch_xla_version), operation, version)

tests/hooks/test_hooks.py

Lines changed: 37 additions & 0 deletions
@@ -18,6 +18,8 @@
 import torch

 from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.hooks.hooks import BaseState, StateManager
+from diffusers.models.cache_utils import CacheMixin
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger
@@ -114,6 +116,27 @@ def reset_state(self, module):
         self.increment = 0


+class CacheTestState(BaseState):
+    def reset(self):
+        pass
+
+
+class CacheTestHook(ModelHook):
+    _is_stateful = True
+
+    def __init__(self):
+        super().__init__()
+        self.state_manager = StateManager(CacheTestState)
+
+    def reset_state(self, module):
+        self.state_manager.reset()
+
+
+class CacheContextModel(torch.nn.Module, CacheMixin):
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x
+
+
 class SkipLayerHook(ModelHook):
     def __init__(self, skip_layer: bool):
         super().__init__()
@@ -338,6 +361,20 @@ def test_invocation_order_stateful_middle(self):
         )
         self.assertEqual(output, expected_invocation_order_log)

+    def test_cache_context_clears_stateful_hook_context_after_exception(self):
+        model = CacheContextModel()
+        hook = CacheTestHook()
+        HookRegistry.check_if_exists_or_initialize(model).register_hook(hook, "cache_test")
+
+        with self.assertRaisesRegex(RuntimeError, "interrupted"):
+            with model.cache_context("failed-call"):
+                hook.state_manager.get_state()
+                raise RuntimeError("interrupted")
+
+        self.assertIsNone(hook.state_manager._current_context)
+        with self.assertRaisesRegex(ValueError, "No context is set"):
+            hook.state_manager.get_state()
+
     def test_invocation_order_stateful_last(self):
         registry = HookRegistry.check_if_exists_or_initialize(self.model)
         registry.register_hook(AddHook(1), "add_hook")
