Skip to content

Commit 68689e5

Browse files
authored
Merge branch 'huggingface:main' into add-neuron-backend
2 parents 28a5086 + 4548e68 commit 68689e5

2 files changed

Lines changed: 95 additions & 13 deletions

File tree

examples/dreambooth/train_dreambooth_lora_qwen_image.py

Lines changed: 70 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -906,6 +906,68 @@ def __getitem__(self, index):
906906
return example
907907

908908

909+
# These helpers only matter for prior preservation, where instance and class prompt
910+
# embedding batches are concatenated and may not share the same mask/sequence length.
911+
def _materialize_prompt_embedding_mask(
912+
prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None
913+
) -> torch.Tensor:
914+
"""Return a dense mask tensor for a prompt embedding batch."""
915+
batch_size, seq_len = prompt_embeds.shape[:2]
916+
917+
if prompt_embeds_mask is None:
918+
return torch.ones((batch_size, seq_len), dtype=torch.long, device=prompt_embeds.device)
919+
920+
if prompt_embeds_mask.shape != (batch_size, seq_len):
921+
raise ValueError(
922+
f"`prompt_embeds_mask` shape {prompt_embeds_mask.shape} must match prompt embeddings shape "
923+
f"({batch_size}, {seq_len})."
924+
)
925+
926+
return prompt_embeds_mask.to(device=prompt_embeds.device)
927+
928+
929+
def _pad_prompt_embedding_pair(
    prompt_embeds: torch.Tensor, prompt_embeds_mask: torch.Tensor | None, target_seq_len: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Right-pad one prompt embedding batch and its mask to ``target_seq_len`` tokens.

    Padding positions get zero embedding vectors and zero mask entries, so they are
    treated as invalid tokens downstream. Batches already at (or beyond) the target
    length are returned unchanged.
    """
    dense_mask = _materialize_prompt_embedding_mask(prompt_embeds, prompt_embeds_mask)
    missing = target_seq_len - prompt_embeds.shape[1]

    # Nothing to do when the batch is already long enough.
    if missing <= 0:
        return prompt_embeds, dense_mask

    batch_size, _, embed_dim = prompt_embeds.shape
    embed_padding = prompt_embeds.new_zeros(batch_size, missing, embed_dim)
    mask_padding = dense_mask.new_zeros(batch_size, missing)

    padded_embeds = torch.cat([prompt_embeds, embed_padding], dim=1)
    padded_mask = torch.cat([dense_mask, mask_padding], dim=1)
    return padded_embeds, padded_mask
947+
948+
949+
def concat_prompt_embedding_batches(
    *prompt_embedding_pairs: tuple[torch.Tensor, torch.Tensor | None],
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """Concatenate prompt embedding batches while handling missing masks and length mismatches.

    Every pair is padded along the sequence axis to the longest batch, then the
    batches are stacked along dim 0. If the merged mask turns out to be all ones it
    is collapsed back to ``None`` so fully-dense batches keep the no-mask code path.
    """
    if not prompt_embedding_pairs:
        raise ValueError("At least one prompt embedding pair must be provided.")

    longest_seq_len = max(embeds.shape[1] for embeds, _ in prompt_embedding_pairs)

    padded_embeds = []
    padded_masks = []
    for embeds, mask in prompt_embedding_pairs:
        embeds, mask = _pad_prompt_embedding_pair(embeds, mask, longest_seq_len)
        padded_embeds.append(embeds)
        padded_masks.append(mask)

    merged_embeds = torch.cat(padded_embeds, dim=0)
    merged_mask = torch.cat(padded_masks, dim=0)

    # An all-ones mask carries no information; drop it.
    return merged_embeds, (None if merged_mask.all() else merged_mask)
969+
970+
909971
def main(args):
910972
if args.report_to == "wandb" and args.hub_token is not None:
911973
raise ValueError(
@@ -1320,8 +1382,10 @@ def compute_text_embeddings(prompt, text_encoding_pipeline):
13201382
prompt_embeds = instance_prompt_embeds
13211383
prompt_embeds_mask = instance_prompt_embeds_mask
13221384
if args.with_prior_preservation:
1323-
prompt_embeds = torch.cat([prompt_embeds, class_prompt_embeds], dim=0)
1324-
prompt_embeds_mask = torch.cat([prompt_embeds_mask, class_prompt_embeds_mask], dim=0)
1385+
prompt_embeds, prompt_embeds_mask = concat_prompt_embedding_batches(
1386+
(instance_prompt_embeds, instance_prompt_embeds_mask),
1387+
(class_prompt_embeds, class_prompt_embeds_mask),
1388+
)
13251389

13261390
# if cache_latents is set to True, we encode images to latents and store them.
13271391
# Similar to pre-encoding in the case of a single instance prompt, if custom prompts are provided
@@ -1465,7 +1529,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
14651529
prompt_embeds = prompt_embeds_cache[step]
14661530
prompt_embeds_mask = prompt_embeds_mask_cache[step]
14671531
else:
1468-
num_repeat_elements = len(prompts)
1532+
# With prior preservation, prompt_embeds already contains [instance, class] embeddings
1533+
# from the cat above, but collate_fn also doubles the prompts list. Use half the
1534+
# prompts count to avoid a 2x over-repeat that produces more embeddings than latents.
1535+
num_repeat_elements = len(prompts) // 2 if args.with_prior_preservation else len(prompts)
14691536
prompt_embeds = prompt_embeds.repeat(num_repeat_elements, 1, 1)
14701537
if prompt_embeds_mask is not None:
14711538
prompt_embeds_mask = prompt_embeds_mask.repeat(num_repeat_elements, 1)

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ def rope_params(self, index, dim, theta=10000):
233233
freqs = torch.polar(torch.ones_like(freqs), freqs)
234234
return freqs
235235

236+
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(pos_freqs, neg_freqs)`` moved onto ``device``, cached per device.

    NOTE(review): the cache keys on ``self`` and ``device`` and is unbounded, so it
    pins one device copy of each table (and the instance itself) for the cache's
    lifetime — presumably cheap since only a handful of devices are ever seen.
    Confirm ``pos_freqs``/``neg_freqs`` are never re-assigned after the first call,
    otherwise the cache would serve stale tensors.
    """
    host_tables = (self.pos_freqs, self.neg_freqs)
    pos_on_device, neg_on_device = (table.to(device) for table in host_tables)
    return pos_on_device, neg_on_device
240+
236241
def forward(
237242
self,
238243
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -300,8 +305,9 @@ def forward(
300305
max_vid_index = max(height, width, max_vid_index)
301306

302307
max_txt_seq_len_int = int(max_txt_seq_len)
303-
# Create device-specific copy for text freqs without modifying self.pos_freqs
304-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
308+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
309+
pos_freqs_device, _ = self._get_device_freqs(device)
310+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
305311
vid_freqs = torch.cat(vid_freqs, dim=0)
306312

307313
return vid_freqs, txt_freqs
@@ -311,8 +317,9 @@ def _compute_video_freqs(
311317
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
312318
) -> torch.Tensor:
313319
seq_lens = frame * height * width
314-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
315-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
320+
pos_freqs, neg_freqs = (
321+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
322+
)
316323

317324
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
318325
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -367,6 +374,11 @@ def rope_params(self, index, dim, theta=10000):
367374
freqs = torch.polar(torch.ones_like(freqs), freqs)
368375
return freqs
369376

377+
@lru_cache_unless_export(maxsize=None)
def _get_device_freqs(self, device: torch.device) -> tuple[torch.Tensor, torch.Tensor]:
    """Return ``(pos_freqs, neg_freqs)`` moved onto ``device``, cached per device.

    NOTE(review): the cache keys on ``self`` and ``device`` and is unbounded, so it
    pins one device copy of each table (and the instance itself) for the cache's
    lifetime — presumably cheap since only a handful of devices are ever seen.
    Confirm ``pos_freqs``/``neg_freqs`` are never re-assigned after the first call,
    otherwise the cache would serve stale tensors.
    """
    host_tables = (self.pos_freqs, self.neg_freqs)
    pos_on_device, neg_on_device = (table.to(device) for table in host_tables)
    return pos_on_device, neg_on_device
381+
370382
def forward(
371383
self,
372384
video_fhw: tuple[int, int, int, list[tuple[int, int, int]]],
@@ -421,17 +433,19 @@ def forward(
421433

422434
max_vid_index = max(max_vid_index, layer_num)
423435
max_txt_seq_len_int = int(max_txt_seq_len)
424-
# Create device-specific copy for text freqs without modifying self.pos_freqs
425-
txt_freqs = self.pos_freqs.to(device)[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
436+
# Use cached device-transferred freqs to avoid CPU→GPU sync every forward call
437+
pos_freqs_device, _ = self._get_device_freqs(device)
438+
txt_freqs = pos_freqs_device[max_vid_index : max_vid_index + max_txt_seq_len_int, ...]
426439
vid_freqs = torch.cat(vid_freqs, dim=0)
427440

428441
return vid_freqs, txt_freqs
429442

430443
@lru_cache_unless_export(maxsize=None)
431444
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
432445
seq_lens = frame * height * width
433-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
434-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
446+
pos_freqs, neg_freqs = (
447+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
448+
)
435449

436450
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
437451
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
@@ -452,8 +466,9 @@ def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device
452466
@lru_cache_unless_export(maxsize=None)
453467
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
454468
seq_lens = frame * height * width
455-
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
456-
neg_freqs = self.neg_freqs.to(device) if device is not None else self.neg_freqs
469+
pos_freqs, neg_freqs = (
470+
self._get_device_freqs(device) if device is not None else (self.pos_freqs, self.neg_freqs)
471+
)
457472

458473
freqs_pos = pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
459474
freqs_neg = neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

0 commit comments

Comments
 (0)