|
22 | 22 | import jax.numpy as jnp |
23 | 23 |
|
24 | 24 | from ..configuration_utils import ConfigMixin, register_to_config |
| 25 | +from ..utils import logging |
25 | 26 | from .scheduling_utils_flax import ( |
26 | 27 | CommonSchedulerState, |
27 | 28 | FlaxKarrasDiffusionSchedulers, |
|
32 | 33 | ) |
33 | 34 |
|
34 | 35 |
|
| 36 | +logger = logging.get_logger(__name__) |
| 37 | + |
| 38 | + |
35 | 39 | @flax.struct.dataclass |
36 | 40 | class DDIMSchedulerState: |
37 | 41 | common: CommonSchedulerState |
@@ -125,6 +129,10 @@ def __init__( |
125 | 129 | prediction_type: str = "epsilon", |
126 | 130 | dtype: jnp.dtype = jnp.float32, |
127 | 131 | ): |
| 132 | + logger.warning( |
| 133 | + "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We " |
| 134 | + "recommend migrating to PyTorch classes or pinning your version of Diffusers." |
| 135 | + ) |
128 | 136 | self.dtype = dtype |
129 | 137 |
|
130 | 138 | def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSchedulerState: |
@@ -152,7 +160,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDIMSch |
152 | 160 | ) |
153 | 161 |
|
154 | 162 | def scale_model_input( |
155 | | - self, state: DDIMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None |
| 163 | + self, |
| 164 | + state: DDIMSchedulerState, |
| 165 | + sample: jnp.ndarray, |
| 166 | + timestep: Optional[int] = None, |
156 | 167 | ) -> jnp.ndarray: |
157 | 168 | """ |
158 | 169 | Args: |
@@ -190,7 +201,9 @@ def set_timesteps( |
190 | 201 | def _get_variance(self, state: DDIMSchedulerState, timestep, prev_timestep): |
191 | 202 | alpha_prod_t = state.common.alphas_cumprod[timestep] |
192 | 203 | alpha_prod_t_prev = jnp.where( |
193 | | - prev_timestep >= 0, state.common.alphas_cumprod[prev_timestep], state.final_alpha_cumprod |
| 204 | + prev_timestep >= 0, |
| 205 | + state.common.alphas_cumprod[prev_timestep], |
| 206 | + state.final_alpha_cumprod, |
194 | 207 | ) |
195 | 208 | beta_prod_t = 1 - alpha_prod_t |
196 | 209 | beta_prod_t_prev = 1 - alpha_prod_t_prev |
|
0 commit comments