Skip to content

Commit 1d32b19

Browse files
authored
Improve docstrings and type hints in scheduling_ddim_flax.py (#13010)
* docs: improve docstring scheduling_ddim_flax.py * docs: improve docstring scheduling_ddim_flax.py * docs: improve docstring scheduling_ddim_flax.py
1 parent 699297f commit 1d32b19

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

src/diffusers/schedulers/scheduling_ddim_flax.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
25+
from ..utils import logging
2526
from .scheduling_utils_flax import (
2627
CommonSchedulerState,
2728
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@
3233
)
3334

3435

36+
logger = logging.get_logger(__name__)
37+
38+
3539
@flax.struct.dataclass
3640
class DDIMSchedulerState:
3741
common: CommonSchedulerState
@@ -125,6 +129,10 @@ def __init__(
125129
prediction_type: str = "epsilon",
126130
dtype: jnp.dtype = jnp.float32,
127131
):
132+
logger.warning(
133+
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
134+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
135+
)
128136
self.dtype = dtype
129137

130138
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState:
@@ -152,7 +160,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSch
152160
)
153161

154162
def scale_model_input(
155-
self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
163+
self,
164+
state: DDIMSchedulerState,
165+
sample: jnp.ndarray,
166+
timestep: Optional[int] = None,
156167
) -> jnp.ndarray:
157168
"""
158169
Args:
@@ -190,7 +201,9 @@ def set_timesteps(
190201
def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep):
191202
alpha_prod_t = state.common.alphas_cumprod[timestep]
192203
alpha_prod_t_prev = jnp.where(
193-
prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod
204+
prev_timestep >= 0,
205+
state.common.alphas_cumprod[prev_timestep],
206+
state.final_alpha_cumprod,
194207
)
195208
beta_prod_t = 1 - alpha_prod_t
196209
beta_prod_t_prev = 1 - alpha_prod_t_prev

0 commit comments

Comments
 (0)