Skip to content

Commit cbc10e2

Browse files
committed
Fix redundant Z-Image terminal timestep
1 parent cbdedba commit cbc10e2

7 files changed

Lines changed: 50 additions & 8 deletions

File tree

src/diffusers/modular_pipelines/z_image/before_denoise.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def retrieve_timesteps(
185185
return timesteps, num_inference_steps
186186

187187

188+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
189+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
191+
192+
188193
class ZImageTextInputStep(ModularPipelineBlocks):
189194
model_name = "z-image"
190195

@@ -535,13 +540,15 @@ def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> P
535540
base_shift=components.scheduler.config.get("base_shift", 0.5),
536541
max_shift=components.scheduler.config.get("max_shift", 1.15),
537542
)
538-
components.scheduler.sigma_min = 0.0
543+
sigmas = block_state.sigmas
544+
if sigmas is None:
545+
sigmas = get_default_z_image_sigmas(block_state.num_inference_steps)
539546

540547
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
541548
components.scheduler,
542549
block_state.num_inference_steps,
543550
device,
544-
sigmas=block_state.sigmas,
551+
sigmas=sigmas,
545552
mu=mu,
546553
)
547554

src/diffusers/pipelines/z_image/pipeline_z_image.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ def retrieve_timesteps(
134134
return timesteps, num_inference_steps
135135

136136

137+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
139+
140+
137141
class ZImagePipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
138142
model_cpu_offload_seq = "text_encoder->transformer->vae"
139143
_optional_components = []
@@ -474,7 +478,8 @@ def __call__(
474478
self.scheduler.config.get("base_shift", 0.5),
475479
self.scheduler.config.get("max_shift", 1.15),
476480
)
477-
self.scheduler.sigma_min = 0.0
481+
if sigmas is None:
482+
sigmas = get_default_z_image_sigmas(num_inference_steps)
478483
scheduler_kwargs = {"mu": mu}
479484
timesteps, num_inference_steps = retrieve_timesteps(
480485
self.scheduler,

src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def retrieve_timesteps(
185185
return timesteps, num_inference_steps
186186

187187

188+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
189+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
191+
192+
188193
class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
189194
model_cpu_offload_seq = "text_encoder->transformer->vae"
190195
_optional_components = []
@@ -593,7 +598,8 @@ def __call__(
593598
self.scheduler.config.get("base_shift", 0.5),
594599
self.scheduler.config.get("max_shift", 1.15),
595600
)
596-
self.scheduler.sigma_min = 0.0
601+
if sigmas is None:
602+
sigmas = get_default_z_image_sigmas(num_inference_steps)
597603
scheduler_kwargs = {"mu": mu}
598604
timesteps, num_inference_steps = retrieve_timesteps(
599605
self.scheduler,

src/diffusers/pipelines/z_image/pipeline_z_image_controlnet_inpaint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ def retrieve_timesteps(
185185
return timesteps, num_inference_steps
186186

187187

188+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
189+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
191+
192+
188193
class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
189194
model_cpu_offload_seq = "text_encoder->transformer->vae"
190195
_optional_components = []
@@ -615,7 +620,8 @@ def __call__(
615620
self.scheduler.config.get("base_shift", 0.5),
616621
self.scheduler.config.get("max_shift", 1.15),
617622
)
618-
self.scheduler.sigma_min = 0.0
623+
if sigmas is None:
624+
sigmas = get_default_z_image_sigmas(num_inference_steps)
619625
scheduler_kwargs = {"mu": mu}
620626
timesteps, num_inference_steps = retrieve_timesteps(
621627
self.scheduler,

src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,11 @@ def retrieve_timesteps(
146146
return timesteps, num_inference_steps
147147

148148

149+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
150+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
152+
153+
149154
class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
150155
r"""
151156
The ZImage pipeline for image-to-image generation.
@@ -563,7 +568,8 @@ def __call__(
563568
self.scheduler.config.get("base_shift", 0.5),
564569
self.scheduler.config.get("max_shift", 1.15),
565570
)
566-
self.scheduler.sigma_min = 0.0
571+
if sigmas is None:
572+
sigmas = get_default_z_image_sigmas(num_inference_steps)
567573
scheduler_kwargs = {"mu": mu}
568574
timesteps, num_inference_steps = retrieve_timesteps(
569575
self.scheduler,

src/diffusers/pipelines/z_image/pipeline_z_image_inpaint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,11 @@ def retrieve_timesteps(
162162
return timesteps, num_inference_steps
163163

164164

165+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
166+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
168+
169+
165170
class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
166171
r"""
167172
The ZImage pipeline for inpainting.
@@ -750,7 +755,8 @@ def __call__(
750755
self.scheduler.config.get("base_shift", 0.5),
751756
self.scheduler.config.get("max_shift", 1.15),
752757
)
753-
self.scheduler.sigma_min = 0.0
758+
if sigmas is None:
759+
sigmas = get_default_z_image_sigmas(num_inference_steps)
754760
scheduler_kwargs = {"mu": mu}
755761
timesteps, num_inference_steps = retrieve_timesteps(
756762
self.scheduler,

src/diffusers/pipelines/z_image/pipeline_z_image_omni.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ def retrieve_timesteps(
135135
return timesteps, num_inference_steps
136136

137137

138+
# Copied from diffusers.pipelines.z_image.pipeline_z_image.get_default_z_image_sigmas
139+
def get_default_z_image_sigmas(num_inference_steps: int) -> list[float]:
    """Return the default Z-Image sigma schedule as a plain list of floats.

    The schedule is a linear ramp that starts at 1.0 and decreases in equal
    steps of ``1 / num_inference_steps``. The terminal value 0.0 is generated
    and then dropped, so the result holds exactly ``num_inference_steps``
    entries and its last entry is ``1 / num_inference_steps``.

    Args:
        num_inference_steps: Number of sigma values to produce.

    Returns:
        A list of ``num_inference_steps`` descending sigmas; empty when
        ``num_inference_steps`` is 0.
    """
    # Build num_inference_steps + 1 evenly spaced points so that both
    # endpoints (1.0 and 0.0) are included in the ramp.
    ramp = torch.linspace(1.0, 0.0, num_inference_steps + 1)
    # Discard the trailing 0.0 (presumably the scheduler supplies the terminal
    # sigma itself, so keeping it would duplicate the last step -- confirm).
    return ramp[:-1].tolist()
141+
142+
138143
class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
139144
model_cpu_offload_seq = "text_encoder->transformer->vae"
140145
_optional_components = []
@@ -604,7 +609,8 @@ def __call__(
604609
self.scheduler.config.get("base_shift", 0.5),
605610
self.scheduler.config.get("max_shift", 1.15),
606611
)
607-
self.scheduler.sigma_min = 0.0
612+
if sigmas is None:
613+
sigmas = get_default_z_image_sigmas(num_inference_steps)
608614
scheduler_kwargs = {"mu": mu}
609615
timesteps, num_inference_steps = retrieve_timesteps(
610616
self.scheduler,

0 commit comments

Comments
 (0)