Skip to content

Commit f18ad5e

Browse files
ytl0623ericspod
andauthored
Fix dead code and logic error (#8714)
### Description 1. added missing `SPADE` test cases to `LATENT_CNDM_TEST_CASES` 2. fixed the `SPADE` branch in `test_sample_intermediates` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [ ] Non-breaking change (fix or new feature that would not break existing functionality). - [x] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: ytl0623 <david89062388@gmail.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 4a601cc commit f18ad5e

File tree

1 file changed

+57
-16
lines changed

1 file changed

+57
-16
lines changed

tests/inferers/test_controlnet_inferers.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,45 @@
201201
(1, 1, 16, 16, 16),
202202
(1, 3, 4, 4, 4),
203203
],
204+
[
205+
"SPADEAutoencoderKL",
206+
{
207+
"spatial_dims": 2,
208+
"in_channels": 1,
209+
"out_channels": 1,
210+
"channels": (4, 4),
211+
"latent_channels": 3,
212+
"attention_levels": [False, False],
213+
"num_res_blocks": 1,
214+
"norm_num_groups": 4,
215+
"label_nc": 5,
216+
},
217+
"SPADEDiffusionModelUNet",
218+
{
219+
"spatial_dims": 2,
220+
"in_channels": 3,
221+
"out_channels": 3,
222+
"channels": [4, 4],
223+
"norm_num_groups": 4,
224+
"attention_levels": [False, False],
225+
"num_res_blocks": 1,
226+
"num_head_channels": 4,
227+
"label_nc": 5,
228+
},
229+
{
230+
"spatial_dims": 2,
231+
"in_channels": 3,
232+
"channels": [4, 4],
233+
"attention_levels": [False, False],
234+
"num_res_blocks": 1,
235+
"norm_num_groups": 4,
236+
"num_head_channels": 4,
237+
"conditioning_embedding_num_channels": [16],
238+
"conditioning_embedding_in_channels": 1,
239+
},
240+
(1, 1, 8, 8),
241+
(1, 3, 4, 4),
242+
],
204243
]
205244
LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [
206245
[
@@ -661,7 +700,7 @@ def test_normal_cdf(self):
661700
x = torch.linspace(-10, 10, 20)
662701
cdf_approx = inferer._approx_standard_normal_cdf(x)
663702
cdf_true = norm.cdf(x)
664-
torch.testing.assert_allclose(cdf_approx, cdf_true, atol=1e-3, rtol=1e-5)
703+
torch.testing.assert_close(cdf_approx, torch.as_tensor(cdf_true, dtype=cdf_approx.dtype), atol=1e-3, rtol=1e-5)
665704

666705
@parameterized.expand(CNDM_TEST_CASES)
667706
@skipUnless(has_einops, "Requires einops")
@@ -742,6 +781,8 @@ def test_prediction_shape(
742781
stage_1 = AutoencoderKL(**autoencoder_params)
743782
if ae_model_type == "VQVAE":
744783
stage_1 = VQVAE(**autoencoder_params)
784+
if ae_model_type == "SPADEAutoencoderKL":
785+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
745786
if dm_model_type == "SPADEDiffusionModelUNet":
746787
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
747788
else:
@@ -764,7 +805,7 @@ def test_prediction_shape(
764805
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
765806
scheduler.set_timesteps(num_inference_steps=10)
766807
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
767-
if dm_model_type == "SPADEDiffusionModelUNet":
808+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
768809
input_shape_seg = list(input_shape)
769810
if "label_nc" in stage_2_params.keys():
770811
input_shape_seg[1] = stage_2_params["label_nc"]
@@ -807,14 +848,16 @@ def test_pred_shape(
807848
):
808849
stage_1 = None
809850

810-
if ae_model_type == "AutoencoderKL":
811-
stage_1 = AutoencoderKL(**autoencoder_params)
812-
if ae_model_type == "VQVAE":
813-
stage_1 = VQVAE(**autoencoder_params)
814851
if dm_model_type == "SPADEDiffusionModelUNet":
815852
stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
816853
else:
817854
stage_2 = DiffusionModelUNet(**stage_2_params)
855+
if ae_model_type == "AutoencoderKL":
856+
stage_1 = AutoencoderKL(**autoencoder_params)
857+
if ae_model_type == "VQVAE":
858+
stage_1 = VQVAE(**autoencoder_params)
859+
if ae_model_type == "SPADEAutoencoderKL":
860+
stage_1 = SPADEAutoencoderKL(**autoencoder_params)
818861
controlnet = ControlNet(**controlnet_params)
819862

820863
device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -905,19 +948,17 @@ def test_sample_intermediates(
905948
else:
906949
input_shape_seg[1] = autoencoder_params["label_nc"]
907950
input_seg = torch.randn(input_shape_seg).to(device)
908-
sample = inferer.sample(
951+
sample, intermediates = inferer.sample(
909952
input_noise=noise,
910953
autoencoder_model=stage_1,
911954
diffusion_model=stage_2,
912955
scheduler=scheduler,
913956
seg=input_seg,
914957
controlnet=controlnet,
915958
cn_cond=mask,
959+
save_intermediates=True,
960+
intermediate_steps=1,
916961
)
917-
918-
# TODO: this isn't correct, should the above produce intermediates as well?
919-
# This test has always passed so is this branch not being used?
920-
intermediates = None
921962
else:
922963
sample, intermediates = inferer.sample(
923964
input_noise=noise,
@@ -973,7 +1014,7 @@ def test_get_likelihoods(
9731014
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
9741015
scheduler.set_timesteps(num_inference_steps=10)
9751016

976-
if dm_model_type == "SPADEDiffusionModelUNet":
1017+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
9771018
input_shape_seg = list(input_shape)
9781019
if "label_nc" in stage_2_params.keys():
9791020
input_shape_seg[1] = stage_2_params["label_nc"]
@@ -1043,7 +1084,7 @@ def test_resample_likelihoods(
10431084
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
10441085
scheduler.set_timesteps(num_inference_steps=10)
10451086

1046-
if dm_model_type == "SPADEDiffusionModelUNet":
1087+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
10471088
input_shape_seg = list(input_shape)
10481089
if "label_nc" in stage_2_params.keys():
10491090
input_shape_seg[1] = stage_2_params["label_nc"]
@@ -1127,7 +1168,7 @@ def test_prediction_shape_conditioned_concat(
11271168

11281169
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
11291170

1130-
if dm_model_type == "SPADEDiffusionModelUNet":
1171+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
11311172
input_shape_seg = list(input_shape)
11321173
if "label_nc" in stage_2_params.keys():
11331174
input_shape_seg[1] = stage_2_params["label_nc"]
@@ -1209,7 +1250,7 @@ def test_sample_shape_conditioned_concat(
12091250
inferer = ControlNetLatentDiffusionInferer(scheduler=scheduler, scale_factor=1.0)
12101251
scheduler.set_timesteps(num_inference_steps=10)
12111252

1212-
if dm_model_type == "SPADEDiffusionModelUNet":
1253+
if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet":
12131254
input_shape_seg = list(input_shape)
12141255
if "label_nc" in stage_2_params.keys():
12151256
input_shape_seg[1] = stage_2_params["label_nc"]
@@ -1290,7 +1331,7 @@ def test_shape_different_latents(
12901331

12911332
timesteps = torch.randint(0, scheduler.num_train_timesteps, (input_shape[0],), device=input.device).long()
12921333

1293-
if dm_model_type == "SPADEDiffusionModelUNet":
1334+
if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
12941335
input_shape_seg = list(input_shape)
12951336
if "label_nc" in stage_2_params.keys():
12961337
input_shape_seg[1] = stage_2_params["label_nc"]

0 commit comments

Comments
 (0)