Skip to content

Commit 8070f6e

Browse files
NIK-TIGER-BILLNIK-TIGER-BILLyiyixuxu
authored
fix(ddim): validate eta is in [0, 1] in DDIMPipeline (#13367)
* fix(ddim): validate eta is in [0, 1] in DDIMPipeline.__call__ The DDIM paper defines η (eta) as a value that must lie in [0, 1]: η=0 corresponds to deterministic DDIM, η=1 corresponds to DDPM. The docstring already documented this constraint, but no runtime validation was in place, so users could silently pass out-of-range values (e.g. negative or >1) without any error. Add an explicit ValueError check before the denoising loop so that invalid eta values are caught early with a clear message. Fixes #13362 Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com> * fix(ddim): downgrade eta out-of-range from error to warning Per maintainer feedback from @yiyixuxu — the documentation is sufficient; a hard ValueError is too strict. Replace with a UserWarning so callers are informed without breaking existing code that passes eta outside [0, 1]. Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com> * fix(ddim): use logger.warning instead of warnings.warn for eta validation Address review request from @yiyixuxu: switch from warnings.warn() to logger.warning() to be consistent with all other diffusers pipelines. The eta validation check itself (0.0 <= eta <= 1.0) is unchanged. Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com> --------- Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com> Co-authored-by: NIK-TIGER-BILL <nik.tiger.bill@github.com> Co-authored-by: YiYi Xu <yixu310@gmail.com>
1 parent 3e53a38 commit 8070f6e

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

src/diffusers/pipelines/ddim/pipeline_ddim.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import logging
16+
1517
import torch
1618

1719
from ...models import UNet2DModel
@@ -21,6 +23,9 @@
2123
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
2224

2325

26+
logger = logging.getLogger(__name__)
27+
28+
2429
if is_torch_xla_available():
2530
import torch_xla.core.xla_model as xm
2631

@@ -129,6 +134,13 @@ def __call__(
129134
else:
130135
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
131136

137+
if not 0.0 <= eta <= 1.0:
138+
logger.warning(
139+
f"`eta` should be between 0 and 1 (inclusive), but received {eta}. "
140+
"A value of 0 corresponds to DDIM and 1 corresponds to DDPM. "
141+
"Unexpected results may occur for values outside this range."
142+
)
143+
132144
if isinstance(generator, list) and len(generator) != batch_size:
133145
raise ValueError(
134146
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

0 commit comments

Comments
 (0)