Skip to content

Commit 256f494

Browse files
authored
Merge branch 'main' into feature/zimage-inpaint-pipeline
2 parents fb48046 + 1d32b19 commit 256f494

3 files changed

Lines changed: 96 additions & 11 deletions

File tree

src/diffusers/models/transformers/transformer_longcat_image.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@ class LongCatImageTransformer2DModel(
406406
"""
407407

408408
_supports_gradient_checkpointing = True
409+
_repeated_blocks = ["LongCatImageTransformerBlock", "LongCatImageSingleTransformerBlock"]
409410

410411
@register_to_config
411412
def __init__(

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
25+
from ..utils import logging
2526
from .scheduling_utils_flax import (
2627
CommonSchedulerState,
2728
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@
3233
)
3334

3435

36+
logger = logging.get_logger(__name__)
37+
38+
3539
@flax.struct.dataclass
3640
class DDIMSchedulerState:
3741
common: CommonSchedulerState
@@ -125,6 +129,10 @@ def __init__(
125129
prediction_type: str = "epsilon",
126130
dtype: jnp.dtype = jnp.float32,
127131
):
132+
logger.warning(
133+
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
134+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
135+
)
128136
self.dtype = dtype
129137

130138
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSch
152160
)
153161

154162
def scale_model_input(
155-
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
163+
self,
164+
state: DDIMSchedulerState,
165+
sample: jnp.ndarray,
166+
timestep: Optional[int] = None,
156167
) -> jnp.ndarray:
157168
"""
158169
Args:
@@ -190,7 +201,9 @@ def set_timesteps(
190201
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
191202
alpha_prod_t = state.common.alphas_cumprod[timestep]
192203
alpha_prod_t_prev = jnp.where(
193-
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
204+
prev_timestep >= 0,
205+
state.common.alphas_cumprod[prev_timestep],
206+
state.final_alpha_cumprod,
194207
)
195208
beta_prod_t = 1 - alpha_prod_t
196209
beta_prod_t_prev = 1 - alpha_prod_t_prev

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 80 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,7 @@ def __init__(
226226
time_shift_type: Literal["exponential"] = "exponential",
227227
sigma_min: Optional[float] = None,
228228
sigma_max: Optional[float] = None,
229+
shift_terminal: Optional[float] = None,
229230
) -> None:
230231
if self.config.use_beta_sigmas and not is_scipy_available():
231232
raise ImportError("Make sure to install scipy if you want to use beta sigmas.")
@@ -245,6 +246,8 @@ def __init__(
245246
self.betas = betas_for_alpha_bar(num_train_timesteps)
246247
else:
247248
raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
249+
if shift_terminal is not None and not use_flow_sigmas:
250+
raise ValueError("`shift_terminal` is only supported when `use_flow_sigmas=True`.")
248251

249252
if rescale_betas_zero_snr:
250253
self.betas = rescale_zero_terminal_snr(self.betas)
@@ -313,8 +316,12 @@ def set_begin_index(self, begin_index: int = 0) -> None:
313316
self._begin_index = begin_index
314317

315318
def set_timesteps(
316-
self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None, mu: Optional[float] = None
317-
) -> None:
319+
self,
320+
num_inference_steps: Optional[int] = None,
321+
device: Union[str, torch.device] = None,
322+
sigmas: Optional[List[float]] = None,
323+
mu: Optional[float] = None,
324+
):
318325
"""
319326
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
320327
@@ -323,13 +330,24 @@ def set_timesteps(
323330
The number of diffusion steps used when generating samples with a pre-trained model.
324331
device (`str` or `torch.device`, *optional*):
325332
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
333+
sigmas (`List[float]`, *optional*):
334+
Custom values for sigmas to be used for each diffusion step. If `None`, the sigmas are computed
335+
automatically.
326336
mu (`float`, *optional*):
327337
Optional mu parameter for dynamic shifting when using exponential time shift type.
328338
"""
339+
if self.config.use_dynamic_shifting and mu is None:
340+
raise ValueError("`mu` must be passed when `use_dynamic_shifting` is set to be `True`")
341+
342+
if sigmas is not None:
343+
if not self.config.use_flow_sigmas:
344+
raise ValueError(
345+
"Passing `sigmas` is only supported when `use_flow_sigmas=True`. "
346+
"Please set `use_flow_sigmas=True` during scheduler initialization."
347+
)
348+
num_inference_steps = len(sigmas)
349+
329350
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://huggingface.co/papers/2305.08891
330-
if mu is not None:
331-
assert self.config.use_dynamic_shifting and self.config.time_shift_type == "exponential"
332-
self.config.flow_shift = np.exp(mu)
333351
if self.config.timestep_spacing == "linspace":
334352
timesteps = (
335353
np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps + 1)
@@ -354,8 +372,9 @@ def set_timesteps(
354372
f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
355373
)
356374

357-
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
358375
if self.config.use_karras_sigmas:
376+
if sigmas is None:
377+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
359378
log_sigmas = np.log(sigmas)
360379
sigmas = np.flip(sigmas).copy()
361380
sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -375,6 +394,8 @@ def set_timesteps(
375394
)
376395
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
377396
elif self.config.use_exponential_sigmas:
397+
if sigmas is None:
398+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
378399
log_sigmas = np.log(sigmas)
379400
sigmas = np.flip(sigmas).copy()
380401
sigmas = self._convert_to_exponential(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -389,6 +410,8 @@ def set_timesteps(
389410
)
390411
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
391412
elif self.config.use_beta_sigmas:
413+
if sigmas is None:
414+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
392415
log_sigmas = np.log(sigmas)
393416
sigmas = np.flip(sigmas).copy()
394417
sigmas = self._convert_to_beta(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
@@ -403,9 +426,18 @@ def set_timesteps(
403426
)
404427
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
405428
elif self.config.use_flow_sigmas:
406-
alphas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)
407-
sigmas = 1.0 - alphas
408-
sigmas = np.flip(self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas))[:-1].copy()
429+
if sigmas is None:
430+
sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1]
431+
if self.config.use_dynamic_shifting:
432+
sigmas = self.time_shift(mu, 1.0, sigmas)
433+
else:
434+
sigmas = self.config.flow_shift * sigmas / (1 + (self.config.flow_shift - 1) * sigmas)
435+
if self.config.shift_terminal:
436+
sigmas = self.stretch_shift_to_terminal(sigmas)
437+
eps = 1e-6
438+
if np.fabs(sigmas[0] - 1) < eps:
439+
# to avoid inf torch.log(alpha_si) in multistep_uni_p_bh_update during first/second update
440+
sigmas[0] -= eps
409441
timesteps = (sigmas * self.config.num_train_timesteps).copy()
410442
if self.config.final_sigmas_type == "sigma_min":
411443
sigma_last = sigmas[-1]
@@ -417,6 +449,8 @@ def set_timesteps(
417449
)
418450
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
419451
else:
452+
if sigmas is None:
453+
sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
420454
sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
421455
if self.config.final_sigmas_type == "sigma_min":
422456
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
@@ -446,6 +480,43 @@ def set_timesteps(
446480
self._begin_index = None
447481
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
448482

483+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
484+
def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
485+
if self.config.time_shift_type == "exponential":
486+
return self._time_shift_exponential(mu, sigma, t)
487+
elif self.config.time_shift_type == "linear":
488+
return self._time_shift_linear(mu, sigma, t)
489+
490+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.stretch_shift_to_terminal
491+
def stretch_shift_to_terminal(self, t: torch.Tensor) -> torch.Tensor:
492+
r"""
493+
Stretches and shifts the timestep schedule to ensure it terminates at the configured `shift_terminal` config
494+
value.
495+
496+
Reference:
497+
https://github.com/Lightricks/LTX-Video/blob/a01a171f8fe3d99dce2728d60a73fecf4d4238ae/ltx_video/schedulers/rf.py#L51
498+
499+
Args:
500+
t (`torch.Tensor`):
501+
A tensor of timesteps to be stretched and shifted.
502+
503+
Returns:
504+
`torch.Tensor`:
505+
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
506+
"""
507+
one_minus_z = 1 - t
508+
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
509+
stretched_t = 1 - (one_minus_z / scale_factor)
510+
return stretched_t
511+
512+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
513+
def _time_shift_exponential(self, mu, sigma, t):
514+
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
515+
516+
# Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
517+
def _time_shift_linear(self, mu, sigma, t):
518+
return mu / (mu + (1 / t - 1) ** sigma)
519+
449520
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
450521
def _threshold_sample(self, sample: torch.Tensor) -> torch.Tensor:
451522
"""

0 commit comments

Comments
 (0)