Skip to content

Commit ff90afd

Browse files
authored
Merge branch 'main' into xla-autoencoder-qwenimage
2 parents b0b8008 + 71a6fd9 commit ff90afd

3 files changed

Lines changed: 14 additions & 12 deletions

File tree

.github/workflows/upload_pr_documentation.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ on:
88

99
jobs:
1010
build:
11-
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
11+
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@9ad2de8582b56c017cb530c1165116d40433f1c6 # main
1212
with:
1313
package_name: diffusers
1414
secrets:

src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -777,7 +777,8 @@ def _prepare_sequence(
777777

778778
# Pad token
779779
feats_cat = torch.cat(feats, dim=0)
780-
feats_cat[torch.cat(inner_pad_mask)] = pad_token
780+
mask = torch.cat(inner_pad_mask).unsqueeze(-1)
781+
feats_cat = torch.where(mask, pad_token, feats_cat)
781782
feats = list(feats_cat.split(item_seqlens, dim=0))
782783

783784
# RoPE

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,15 @@ def __call__(
486486
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
487487
self._num_timesteps = len(timesteps)
488488

489+
# We set the index here to remove DtoH sync, helpful especially during compilation.
490+
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
491+
self.scheduler.set_begin_index(0)
492+
493+
if self.do_classifier_free_guidance and self._cfg_truncation is not None and float(self._cfg_truncation) <= 1:
494+
_precomputed_t_norms = ((1000 - timesteps.float()) / 1000).tolist()
495+
else:
496+
_precomputed_t_norms = None
497+
489498
# 6. Denoising loop
490499
with self.progress_bar(total=num_inference_steps) as progress_bar:
491500
for i, t in enumerate(timesteps):
@@ -495,17 +504,9 @@ def __call__(
495504
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
496505
timestep = t.expand(latents.shape[0])
497506
timestep = (1000 - timestep) / 1000
498-
# Normalized time for time-aware config (0 at start, 1 at end)
499-
t_norm = timestep[0].item()
500-
501-
# Handle cfg truncation
502507
current_guidance_scale = self.guidance_scale
503-
if (
504-
self.do_classifier_free_guidance
505-
and self._cfg_truncation is not None
506-
and float(self._cfg_truncation) <= 1
507-
):
508-
if t_norm > self._cfg_truncation:
508+
if _precomputed_t_norms is not None:
509+
if _precomputed_t_norms[i] > self._cfg_truncation:
509510
current_guidance_scale = 0.0
510511

511512
# Run CFG only if configured AND scale is non-zero

0 commit comments

Comments (0)