Skip to content

Commit e8e88ff

Browse files
authored
Improve docstrings and type hints in scheduling_ddpm_flax.py (#13024)
docs: improve docstring scheduling_ddpm_flax.py
1 parent 6e24cd8 commit e8e88ff

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

src/diffusers/schedulers/scheduling_ddpm_flax.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import jax.numpy as jnp
2323

2424
from ..configuration_utils import ConfigMixin, register_to_config
25+
from ..utils import logging
2526
from .scheduling_utils_flax import (
2627
CommonSchedulerState,
2728
FlaxKarrasDiffusionSchedulers,
@@ -32,6 +33,9 @@
3233
)
3334

3435

36+
logger = logging.get_logger(__name__)
37+
38+
3539
@flax.struct.dataclass
3640
class DDPMSchedulerState:
3741
common: CommonSchedulerState
@@ -42,7 +46,12 @@ class DDPMSchedulerState:
4246
num_inference_steps: Optional[int] = None
4347

4448
@classmethod
45-
def create(cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray):
49+
def create(
50+
cls,
51+
common: CommonSchedulerState,
52+
init_noise_sigma: jnp.ndarray,
53+
timesteps: jnp.ndarray,
54+
):
4655
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps)
4756

4857

@@ -105,6 +114,10 @@ def __init__(
105114
prediction_type: str = "epsilon",
106115
dtype: jnp.dtype = jnp.float32,
107116
):
117+
logger.warning(
118+
"Flax classes are deprecated and will be removed in Diffusers v1.0.0. We "
119+
"recommend migrating to PyTorch classes or pinning your version of Diffusers."
120+
)
108121
self.dtype = dtype
109122

110123
def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSchedulerState:
@@ -123,7 +136,10 @@ def create_state(self, common: Optional[CommonSchedulerState] = None) -> DDPMSch
123136
)
124137

125138
def scale_model_input(
126-
self, state: DDPMSchedulerState, sample: jnp.ndarray, timestep: Optional[int] = None
139+
self,
140+
state: DDPMSchedulerState,
141+
sample: jnp.ndarray,
142+
timestep: Optional[int] = None,
127143
) -> jnp.ndarray:
128144
"""
129145
Args:

0 commit comments

Comments
 (0)