|
| 1 | +import inspect |
1 | 2 | from unittest.mock import MagicMock, patch |
2 | 3 |
|
3 | 4 | import pytest |
|
9 | 10 | from invokeai.app.invocations.flux_denoise import FluxDenoiseInvocation |
10 | 11 | from invokeai.app.invocations.sd3_denoise import SD3DenoiseInvocation |
11 | 12 | from invokeai.app.invocations.z_image_denoise import ZImageDenoiseInvocation |
| 13 | +from invokeai.backend.flux.sampling_utils import clip_timestep_schedule_fractional, get_schedule |
| 14 | +from invokeai.backend.flux.schedulers import ANIMA_SCHEDULER_MAP, FLUX_SCHEDULER_MAP, ZIMAGE_SCHEDULER_MAP |
| 15 | +from invokeai.backend.flux2.sampling_utils import compute_empirical_mu, get_schedule_flux2 |
12 | 16 |
|
13 | 17 |
|
14 | 18 | def test_flux_prepare_noise_uses_external_noise(): |
@@ -165,3 +169,236 @@ def test_anima_prepare_noise_rejects_invalid_rank(): |
165 | 169 |
|
166 | 170 | with pytest.raises(ValueError, match="Expected noise with shape"): |
167 | 171 | invocation._prepare_noise_tensor(mock_context, torch.bfloat16, torch.device("cpu")) |
| 172 | + |
| 173 | + |
| 174 | +def _get_first_scheduler_sigma( |
| 175 | + scheduler, *, scheduler_name: str, sigmas: list[float], mu: float | None = None |
| 176 | +) -> float: |
| 177 | + set_timesteps_signature = inspect.signature(scheduler.set_timesteps) |
| 178 | + if scheduler_name != "lcm" and "sigmas" in set_timesteps_signature.parameters: |
| 179 | + kwargs: dict[str, object] = {"sigmas": sigmas, "device": "cpu"} |
| 180 | + if mu is not None and "mu" in set_timesteps_signature.parameters: |
| 181 | + kwargs["mu"] = mu |
| 182 | + scheduler.set_timesteps(**kwargs) |
| 183 | + else: |
| 184 | + scheduler.set_timesteps(num_inference_steps=len(sigmas) - 1, device="cpu") |
| 185 | + return float(scheduler.sigmas[0]) |
| 186 | + |
| 187 | + |
# Shared expected-failure mark for the FLUX schedulers that do not yet honor
# the externally clipped schedule's first step.
_FLUX_XFAIL = pytest.mark.xfail(
    reason="Known img2img preblend mismatch for FLUX with scheduler-defined first step.",
    strict=True,
)


@pytest.mark.parametrize(
    "scheduler_name",
    [
        "euler",
        pytest.param("heun", marks=_FLUX_XFAIL),
        pytest.param("lcm", marks=_FLUX_XFAIL),
    ],
)
def test_flux_img2img_preblend_matches_scheduler_first_sigma(scheduler_name: str):
    """The first sigma of the fractionally clipped FLUX schedule must equal the
    scheduler's own first sigma, so img2img pre-blending and the denoise loop
    agree on the starting noise level."""
    full_schedule = get_schedule(num_steps=4, image_seq_len=16, shift=True)
    clipped = clip_timestep_schedule_fractional(full_schedule, 0.25, 1.0)
    scheduler = FLUX_SCHEDULER_MAP[scheduler_name](num_train_timesteps=1000)

    first_sigma = _get_first_scheduler_sigma(
        scheduler, scheduler_name=scheduler_name, sigmas=clipped
    )
    assert clipped[0] == pytest.approx(first_sigma)
| 216 | + |
| 217 | + |
@pytest.mark.parametrize(
    "scheduler_name",
    [
        pytest.param(name, marks=pytest.mark.xfail(reason=reason, strict=True))
        for name, reason in [
            ("euler", "Known img2img preblend mismatch for FLUX.2 scheduler path."),
            ("heun", "Known img2img preblend mismatch for FLUX.2 scheduler path."),
            ("lcm", "Known FLUX.2 scheduler-path limitation for img2img parity."),
        ]
    ],
)
def test_flux2_img2img_preblend_matches_scheduler_first_sigma(scheduler_name: str):
    """FLUX.2 variant of the pre-blend parity check: all schedulers are
    currently expected to fail (strict xfail) on the scheduler-defined
    first-step path."""
    clipped = clip_timestep_schedule_fractional(
        get_schedule_flux2(num_steps=4, image_seq_len=16), 0.25, 1.0
    )
    mu = compute_empirical_mu(image_seq_len=16, num_steps=4)
    scheduler_cls = FLUX_SCHEDULER_MAP[scheduler_name]
    if scheduler_name == "heun":
        # Heun's scheduler does not take the dynamic-shifting knobs.
        scheduler = scheduler_cls(num_train_timesteps=1000, shift=3.0)
    else:
        scheduler = scheduler_cls(
            num_train_timesteps=1000,
            shift=3.0,
            use_dynamic_shifting=True,
            base_shift=0.5,
            max_shift=1.15,
            base_image_seq_len=256,
            max_image_seq_len=4096,
            time_shift_type="exponential",
        )

    # NOTE(review): the terminal sigma is dropped here (clipped[:-1]) unlike
    # the sibling tests — presumably intentional for the FLUX.2 path; the
    # strict xfail marks pin the current mismatch either way.
    first_sigma = _get_first_scheduler_sigma(
        scheduler, scheduler_name=scheduler_name, sigmas=clipped[:-1], mu=mu
    )
    assert clipped[0] == pytest.approx(first_sigma)
| 265 | + |
| 266 | + |
@pytest.mark.parametrize(
    "scheduler_name",
    ["euler"]
    + [
        pytest.param(
            name,
            marks=pytest.mark.xfail(
                reason="Known img2img preblend mismatch for Z-Image with scheduler-defined first step.",
                strict=True,
            ),
        )
        for name in ("heun", "lcm")
    ],
)
def test_z_image_img2img_preblend_matches_scheduler_first_sigma(scheduler_name: str):
    """Z-Image pre-blend parity: the first sigma of the 25%-clipped node
    schedule must match the scheduler's first sigma."""
    node = ZImageDenoiseInvocation.model_construct(steps=8, width=1024, height=1024)
    # Sequence length after 8x VAE downsample and 2x2 patching.
    seq_len = (node.height // 8 // 2) * (node.width // 8 // 2)
    schedule = node._get_sigmas(node._calculate_shift(seq_len), node.steps)
    start_index = int(0.25 * (len(schedule) - 1))
    clipped = schedule[start_index:]
    scheduler = ZIMAGE_SCHEDULER_MAP[scheduler_name](num_train_timesteps=1000, shift=1.0)

    first_sigma = _get_first_scheduler_sigma(
        scheduler, scheduler_name=scheduler_name, sigmas=clipped
    )
    assert clipped[0] == pytest.approx(first_sigma)
| 299 | + |
| 300 | + |
@pytest.mark.parametrize(
    "scheduler_name",
    ["euler"]
    + [
        pytest.param(
            name,
            marks=pytest.mark.xfail(
                reason="Known img2img preblend mismatch for Anima with scheduler-defined first step.",
                strict=True,
            ),
        )
        for name in ("heun", "lcm")
    ],
)
def test_anima_img2img_preblend_matches_scheduler_first_sigma(scheduler_name: str):
    """Anima pre-blend parity: the first sigma of the 25%-clipped node
    schedule must match the scheduler's first sigma."""
    node = AnimaDenoiseInvocation.model_construct(steps=30)
    schedule = node._get_sigmas(node.steps)
    start_index = int(0.25 * (len(schedule) - 1))
    clipped = schedule[start_index:]
    scheduler = ANIMA_SCHEDULER_MAP[scheduler_name](num_train_timesteps=1000, shift=1.0)

    first_sigma = _get_first_scheduler_sigma(
        scheduler, scheduler_name=scheduler_name, sigmas=clipped
    )
    assert clipped[0] == pytest.approx(first_sigma)
| 331 | + |
| 332 | + |
def test_sd3_partial_denoise_short_circuit_uses_first_clipped_timestep():
    """When denoising_start == denoising_end the SD3 node should skip the
    denoise loop entirely and return noise/init-latents blended at the first
    clipped timestep."""
    node = SD3DenoiseInvocation.model_construct(
        latents=MagicMock(latents_name="latents"),
        width=64,
        height=64,
        steps=4,
        denoising_start=0.25,
        denoising_end=0.25,
        positive_conditioning=MagicMock(conditioning_name="positive"),
        negative_conditioning=MagicMock(conditioning_name="negative"),
        transformer=MagicMock(transformer="transformer"),
        seed=0,
    )
    # Constant-valued tensors make the expected blend easy to verify exactly.
    source_latents = torch.full((1, 16, 8, 8), 2.0)
    injected_noise = torch.full((1, 16, 8, 8), 10.0)
    ctx = MagicMock()
    ctx.tensors.load.return_value = source_latents
    ctx.models.load.return_value = MagicMock(
        model=MagicMock(config=MagicMock(in_channels=16, joint_attention_dim=4096))
    )

    with (
        patch("invokeai.app.invocations.sd3_denoise.TorchDevice.choose_torch_device", return_value=torch.device("cpu")),
        patch("invokeai.app.invocations.sd3_denoise.TorchDevice.choose_torch_dtype", return_value=torch.float32),
        patch.object(node, "_prepare_noise_tensor", return_value=injected_noise),
        patch.object(node, "_load_text_conditioning", return_value=(torch.zeros(1, 1, 1), torch.zeros(1, 1))),
    ):
        result = node._run_diffusion(ctx)

    schedule = clip_timestep_schedule_fractional(torch.linspace(1, 0, node.steps + 1).tolist(), 0.25, 0.25)
    t0 = schedule[0]
    expected = t0 * injected_noise + (1.0 - t0) * source_latents
    assert torch.equal(result, expected)
| 365 | + |
| 366 | + |
def test_cogview4_partial_denoise_short_circuit_uses_first_clipped_sigma():
    """When denoising_start == denoising_end the CogView4 node should skip the
    denoise loop and return noise/init-latents blended at the first clipped
    sigma (after timestep->sigma conversion)."""
    node = CogView4DenoiseInvocation.model_construct(
        latents=MagicMock(latents_name="latents"),
        width=64,
        height=64,
        steps=4,
        denoising_start=0.25,
        denoising_end=0.25,
        positive_conditioning=MagicMock(conditioning_name="positive"),
        negative_conditioning=MagicMock(conditioning_name="negative"),
        transformer=MagicMock(transformer="transformer"),
        seed=0,
    )
    # Constant-valued tensors make the expected blend easy to verify.
    source_latents = torch.full((1, 16, 8, 8), 2.0)
    injected_noise = torch.full((1, 16, 8, 8), 10.0)
    ctx = MagicMock()
    ctx.tensors.load.return_value = source_latents
    fake_transformer = MagicMock(config=MagicMock(in_channels=16, patch_size=2))
    ctx.models.load.return_value = MagicMock(model=fake_transformer)

    with (
        # Neutralize the isinstance check against the real diffusers class.
        patch("invokeai.app.invocations.cogview4_denoise.CogView4Transformer2DModel", object),
        patch(
            "invokeai.app.invocations.cogview4_denoise.TorchDevice.choose_torch_device",
            return_value=torch.device("cpu"),
        ),
        patch.object(node, "_prepare_noise_tensor", return_value=injected_noise),
        patch.object(node, "_load_text_conditioning", return_value=torch.zeros(1, 1, 1)),
    ):
        result = node._run_diffusion(ctx)

    schedule = clip_timestep_schedule_fractional(torch.linspace(1, 0, node.steps + 1).tolist(), 0.25, 0.25)
    # Latent grid is (h/8)x(w/8), patched 2x2 -> divide by 2**2.
    seq_len = ((node.height // 8) * (node.width // 8)) // (2**2)
    sigma_schedule = node._convert_timesteps_to_sigmas(
        image_seq_len=seq_len,
        timesteps=torch.tensor(schedule),
    )
    s0 = sigma_schedule[0]
    expected = s0 * injected_noise + (1.0 - s0) * source_latents
    assert torch.allclose(result, expected, atol=2e-3, rtol=0)
0 commit comments