
Commit c4facab

Merge branch 'add-neuron-backend' of github.com:JingyaHuang/diffusers into add-neuron-backend
2 parents: 3bb9c7c + da79308

6 files changed

Lines changed: 141 additions & 23 deletions

File tree:
  .github/workflows/pr_dependency_test.yml
  .github/workflows/pr_torch_dependency_test.yml
  examples/dreambooth/train_dreambooth_lora_qwen_image.py
  src/diffusers/models/transformers/transformer_qwenimage.py
  src/diffusers/pipelines/consisid/consisid_utils.py
  tests/others/test_dependencies.py

.github/workflows/pr_dependency_test.yml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ on:
       - main
     paths:
       - "src/diffusers/**.py"
+      - "tests/**.py"
   push:
     branches:
       - main

.github/workflows/pr_torch_dependency_test.yml

Lines changed: 2 additions & 1 deletion
@@ -6,6 +6,7 @@ on:
       - main
     paths:
       - "src/diffusers/**.py"
+      - "tests/**.py"
   push:
     branches:
       - main
@@ -26,7 +27,7 @@ jobs:
       - name: Install dependencies
         run: |
           pip install -e .
-          pip install torch torchvision torchaudio pytest
+          pip install torch pytest
       - name: Check for soft dependencies
         run: |
           pytest tests/others/test_dependencies.py
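Note on this change: torchvision and torchaudio are deliberately dropped from the CI environment so that the dependency tests run against torch alone and any unguarded torchvision import surfaces as a failure. A minimal sketch of reproducing the "Check for soft dependencies" step locally, assuming the same setup as the workflow (pip install -e . followed by pip install torch pytest):

# Sketch: reproduce the CI soft-dependency check locally. Assumes an environment
# prepared like the workflow above (editable install plus torch and pytest,
# deliberately without torchvision/torchaudio).
import sys

import pytest

# Run the same test file the workflow invokes.
sys.exit(pytest.main(["tests/others/test_dependencies.py", "-q"]))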

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 70 additions & 3 deletions
@@ -906,6 +906,68 @@ def __getitem__(self, index):
         return example


+# These helpers only matter for prior preservation, where instance and class prompt
+# embedding batches are concatenated and may not share the same mask/sequence length.
+def _materialize_prompt_embedding_mask(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
+) -> torch.Tensor:
+    """Return a dense mask tensor for a prompt embedding batch."""
+    batch_size, seq_len = prompt_embeds.shape[:2]
+
+    if prompt_embeds_mask is None:
+        return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
+
+    if prompt_embeds_mask.shape != (batch_size, seq_len):
+        raise ValueError(
+            f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
+            f"({batch_size}, {seq_len})."
+        )
+
+    return prompt_embeds_mask.to(device=prompt_embeds.device)
+
+
+def _pad_prompt_embedding_pair(
+    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Pad one prompt embedding batch and its mask to a shared sequence length."""
+    prompt_embeds_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
+    pad_width = target_seq_len - prompt_embeds.shape[1]
+
+    if pad_width <= 0:
+        return prompt_embeds, prompt_embeds_mask
+
+    prompt_embeds = torch.cat(
+        [prompt_embeds, prompt_embeds.new_zeros(prompt_embeds.shape[0], pad_width, prompt_embeds.shape[2])], dim=1
+    )
+    prompt_embeds_mask = torch.cat(
+        [prompt_embeds_mask, prompt_embeds_mask.new_zeros(prompt_embeds_mask.shape[0], pad_width)], dim=1
+    )
+
+    return prompt_embeds, prompt_embeds_mask
+
+
+def concat_prompt_embedding_batches(
+    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
+) -> tuple[torch.Tensor, torch.Tensor | None]:
+    """Concatenate prompt embedding batches while handling missing masks and length mismatches."""
+    if not prompt_embedding_pairs:
+        raise ValueError("At least one prompt embedding pair must be provided.")
+
+    target_seq_len = max(prompt_embeds.shape[1] for prompt_embeds, _ in prompt_embedding_pairs)
+    padded_pairs = [
+        _pad_prompt_embedding_pair(prompt_embeds, prompt_embeds_mask, target_seq_len)
+        for prompt_embeds, prompt_embeds_mask in prompt_embedding_pairs
+    ]
+
+    merged_prompt_embeds = torch.cat([prompt_embeds for prompt_embeds, _ in padded_pairs], dim=0)
+    merged_mask = torch.cat([prompt_embeds_mask for _, prompt_embeds_mask in padded_pairs], dim=0)
+
+    if merged_mask.all():
+        return merged_prompt_embeds, None
+
+    return merged_prompt_embeds, merged_mask
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
         prompt_embeds = instance_prompt_embeds
         prompt_embeds_mask = instance_prompt_embeds_mask
         if args.with_prior_preservation:
-            prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
-            prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
+            prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
+                (instance_prompt_embeds, instance_prompt_embeds_mask),
+                (class_prompt_embeds, class_prompt_embeds_mask),
+            )

     # if cache_latents is set to True, we encode images to latents and store them.
     # Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
@@ -1465,7 +1529,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 prompt_embeds = prompt_embeds_cache[step]
                 prompt_embeds_mask = prompt_embeds_mask_cache[step]
             else:
-                num_repeat_elements = len(prompts)
+                # With prior preservation, prompt_embeds already contains [instance, class] embeddings
+                # from the cat above, but collate_fn also doubles the prompts list. Use half the
+                # prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
+                num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
                 prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
                 if prompt_embeds_mask is not None:
                     prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)
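To make the helper's behavior concrete: with prior preservation, the instance and class prompt embeddings can have different sequence lengths, so the shorter batch is zero-padded and its mask marks the padded positions before the two batches are stacked. A standalone toy sketch of that idea (shapes and values are illustrative, not taken from the training script):

# Toy illustration of the padding-and-concatenation idea behind
# concat_prompt_embedding_batches; the shapes below are made up for the example.
import torch

instance_embeds = torch.randn(2, 5, 8)  # batch 2, sequence length 5, hidden dim 8
class_embeds = torch.randn(2, 7, 8)     # batch 2, sequence length 7, hidden dim 8

target_len = max(instance_embeds.shape[1], class_embeds.shape[1])
pad = target_len - instance_embeds.shape[1]

# Zero-pad the shorter batch along the sequence dimension.
instance_padded = torch.cat([instance_embeds, instance_embeds.new_zeros(2, pad, 8)], dim=1)

# Masks mark real tokens (1) vs. padding (0) so downstream attention ignores the padding.
instance_mask = torch.cat([torch.ones(2, 5, dtype=torch.long), torch.zeros(2, pad, dtype=torch.long)], dim=1)
class_mask = torch.ones(2, 7, dtype=torch.long)

merged_embeds = torch.cat([instance_padded, class_embeds], dim=0)  # (4, 7, 8)
merged_mask = torch.cat([instance_mask, class_mask], dim=0)        # (4, 7)
print(merged_embeds.shape, merged_mask.shape)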

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 25 additions & 10 deletions
@@ -233,6 +233,11 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs

+    @lru_cache_unless_export(maxsize=None)
+    def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return pos_freqs and neg_freqs on the given device."""
+        return self.pos_freqs.to(device), self.neg_freqs.to(device)
+
     def forward(
         self,
         video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -300,8 +305,9 @@ def forward(
             max_vid_index = max(height, width, max_vid_index)

         max_txt_seq_len_int = int(max_txt_seq_len)
-        # Create device-specific copy for text freqs without modifying self.pos_freqs
-        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
+        # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
+        pos_freqs_device, _ = self._get_device_freqs(device)
+        txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs
@@ -311,8 +317,9 @@ def _compute_video_freqs(
         self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
     ) -> torch.Tensor:
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -367,6 +374,11 @@ def rope_params(self, index, dim, theta=10000):
         freqs = torch.polar(torch.ones_like(freqs), freqs)
         return freqs

+    @lru_cache_unless_export(maxsize=None)
+    def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
+        """Return pos_freqs and neg_freqs on the given device."""
+        return self.pos_freqs.to(device), self.neg_freqs.to(device)
+
     def forward(
         self,
         video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -421,17 +433,19 @@ def forward(

         max_vid_index = max(max_vid_index, layer_num)
         max_txt_seq_len_int = int(max_txt_seq_len)
-        # Create device-specific copy for text freqs without modifying self.pos_freqs
-        txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
+        # Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
+        pos_freqs_device, _ = self._get_device_freqs(device)
+        txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
         vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

     @lru_cache_unless_export(maxsize=None)
     def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -452,8 +466,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
     @lru_cache_unless_export(maxsize=None)
     def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
         seq_lens = frame * height * width
-        pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
-        neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
+        pos_freqs, neg_freqs = (
+            self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
+        )

         freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
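The point of _get_device_freqs is to pay for the CPU-to-device transfer of pos_freqs and neg_freqs once per device and then reuse the cached tensors on every subsequent forward call. lru_cache_unless_export is a diffusers-internal decorator; a rough standalone sketch of the same caching pattern using functools.lru_cache (the class name and buffer shapes below are illustrative only):

# Standalone sketch of caching device-transferred buffers so .to(device) is not
# repeated on every forward call. functools.lru_cache stands in for the
# project-internal lru_cache_unless_export; shapes are made up for the example.
from functools import lru_cache

import torch


class RopeFreqCache:
    def __init__(self) -> None:
        # Precomputed frequency tables, kept on CPU until first used on a device.
        self.pos_freqs = torch.randn(1024, 64)
        self.neg_freqs = torch.randn(1024, 64)

    @lru_cache(maxsize=None)
    def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
        # Cached per device: the transfer happens once, later calls reuse the same tensors.
        return self.pos_freqs.to(device), self.neg_freqs.to(device)


cache = RopeFreqCache()
pos, neg = cache._get_device_freqs(torch.device("cpu"))
print(pos.device, neg.device)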

src/diffusers/pipelines/consisid/consisid_utils.py

Lines changed: 6 additions & 3 deletions
@@ -5,10 +5,13 @@
 import numpy as np
 import torch
 from PIL import Image, ImageOps
-from torchvision.transforms import InterpolationMode
-from torchvision.transforms.functional import normalize, resize

-from ...utils import get_logger, load_image
+from ...utils import get_logger, is_torchvision_available, load_image
+
+
+if is_torchvision_available():
+    from torchvision.transforms import InterpolationMode
+    from torchvision.transforms.functional import normalize, resize


 logger = get_logger(__name__)
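This follows the usual optional-dependency pattern: the torchvision imports only run when is_torchvision_available() reports the package is installed, so importing consisid_utils no longer hard-requires torchvision. A generic standalone sketch of the same guard, using importlib.util.find_spec in place of the diffusers helper and a hypothetical resize_image function:

# Generic sketch of guarding an optional dependency at import time, so the module
# can be imported even when the extra package is missing. importlib.util.find_spec
# stands in for diffusers' is_torchvision_available; resize_image is hypothetical.
import importlib.util

_TORCHVISION_AVAILABLE = importlib.util.find_spec("torchvision") is not None

if _TORCHVISION_AVAILABLE:
    from torchvision.transforms import InterpolationMode
    from torchvision.transforms.functional import resize


def resize_image(img, size):
    # Fail only when the torchvision-dependent code path is actually used.
    if not _TORCHVISION_AVAILABLE:
        raise ImportError("torchvision is required for resize_image; please install torchvision.")
    return resize(img, size, interpolation=InterpolationMode.BICUBIC)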

tests/others/test_dependencies.py

Lines changed: 37 additions & 6 deletions
@@ -13,16 +13,14 @@
 # limitations under the License.

 import inspect
-import unittest
 from importlib import import_module

+import pytest

-class DependencyTester(unittest.TestCase):
+
+class TestDependencies:
     def test_diffusers_import(self):
-        try:
-            import diffusers  # noqa: F401
-        except ImportError:
-            assert False
+        import diffusers  # noqa: F401

     def test_backend_registration(self):
         import diffusers
@@ -52,3 +50,36 @@ def test_pipeline_imports(self):
             if hasattr(diffusers.pipelines, cls_name):
                 pipeline_folder_module = ".".join(str(cls_module.__module__).split(".")[:3])
                 _ = import_module(pipeline_folder_module, str(cls_name))
+
+    def test_pipeline_module_imports(self):
+        """Import every pipeline submodule whose dependencies are satisfied,
+        to catch unguarded optional-dep imports (e.g., torchvision).
+
+        Uses inspect.getmembers to discover classes that the lazy loader can
+        actually resolve (same self-filtering as test_pipeline_imports), then
+        imports the full module path instead of truncating to the folder level.
+        """
+        import diffusers
+        import diffusers.pipelines
+
+        failures = []
+        all_classes = inspect.getmembers(diffusers, inspect.isclass)
+
+        for cls_name, cls_module in all_classes:
+            if not hasattr(diffusers.pipelines, cls_name):
+                continue
+            if "dummy_" in cls_module.__module__:
+                continue
+
+            full_module_path = cls_module.__module__
+            try:
+                import_module(full_module_path)
+            except ImportError as e:
+                failures.append(f"{full_module_path}: {e}")
+            except Exception:
+                # Non-import errors (e.g., missing config) are fine; we only
+                # care about unguarded import statements.
+                pass
+
+        if failures:
+            pytest.fail("Unguarded optional-dependency imports found:\n" + "\n".join(failures))
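The new test imports each resolvable pipeline module directly, so an unguarded torchvision import (such as the one fixed in consisid_utils above) is reported as a failure when torchvision is absent. A sketch for running just this check, using pytest's standard file::class::test node-id form:

# Sketch: run only the new unguarded-import check. Equivalent to invoking
# "pytest tests/others/test_dependencies.py::TestDependencies::test_pipeline_module_imports -q"
# from the repository root.
import sys

import pytest

sys.exit(
    pytest.main(["tests/others/test_dependencies.py::TestDependencies::test_pipeline_module_imports", "-q"])
)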
