2222import jax .numpy as jnp
2323
2424from ..configuration_utils import ConfigMixin , register_to_config
25+ from ..utils import logging
2526from .scheduling_utils_flax import (
2627 CommonSchedulerState ,
2728 FlaxKarrasDiffusionSchedulers ,
3233)
3334
3435
36+ logger = logging .get_logger (__name__ )
37+
38+
3539@flax .struct .dataclass
3640class DDPMSchedulerState :
3741 common : CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
4246 num_inference_steps : Optional [int ] = None
4347
4448 @classmethod
45- def create (cls , common : CommonSchedulerState , init_noise_sigma : jnp .ndarray , timesteps : jnp .ndarray ):
49+ def create (
50+ cls ,
51+ common : CommonSchedulerState ,
52+ init_noise_sigma : jnp .ndarray ,
53+ timesteps : jnp .ndarray ,
54+ ):
4655 return cls (common = common , init_noise_sigma = init_noise_sigma , timesteps = timesteps )
4756
4857
@@ -105,6 +114,10 @@ def __init__(
105114 prediction_type : str = "epsilon" ,
106115 dtype : jnp .dtype = jnp .float32 ,
107116 ):
117+ logger .warning (
118+ "Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
119+ "recommend migrating to PyTorch classes or pinning your version of Diffusers."
120+ )
108121 self .dtype = dtype
109122
110123 def create_state (self , common : Optional [CommonSchedulerState ] = None ) -> DDPMSchedulerState :
@@ -123,7 +136,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSch
123136 )
124137
125138 def scale_model_input (
126- self , state : DDPMSchedulerState , sample : jnp .ndarray , timestep : Optional [int ] = None
139+ self ,
140+ state : DDPMSchedulerState ,
141+ sample : jnp .ndarray ,
142+ timestep : Optional [int ] = None ,
127143 ) -> jnp .ndarray :
128144 """
129145 Args:
0 commit comments