201201 (1 , 1 , 16 , 16 , 16 ),
202202 (1 , 3 , 4 , 4 , 4 ),
203203 ],
204+ [
205+ "SPADEAutoencoderKL" ,
206+ {
207+ "spatial_dims" : 2 ,
208+ "in_channels" : 1 ,
209+ "out_channels" : 1 ,
210+ "channels" : (4 , 4 ),
211+ "latent_channels" : 3 ,
212+ "attention_levels" : [False , False ],
213+ "num_res_blocks" : 1 ,
214+ "norm_num_groups" : 4 ,
215+ "label_nc" : 5 ,
216+ },
217+ "SPADEDiffusionModelUNet" ,
218+ {
219+ "spatial_dims" : 2 ,
220+ "in_channels" : 3 ,
221+ "out_channels" : 3 ,
222+ "channels" : [4 , 4 ],
223+ "norm_num_groups" : 4 ,
224+ "attention_levels" : [False , False ],
225+ "num_res_blocks" : 1 ,
226+ "num_head_channels" : 4 ,
227+ "label_nc" : 5 ,
228+ },
229+ {
230+ "spatial_dims" : 2 ,
231+ "in_channels" : 3 ,
232+ "channels" : [4 , 4 ],
233+ "attention_levels" : [False , False ],
234+ "num_res_blocks" : 1 ,
235+ "norm_num_groups" : 4 ,
236+ "num_head_channels" : 4 ,
237+ "conditioning_embedding_num_channels" : [16 ],
238+ "conditioning_embedding_in_channels" : 1 ,
239+ },
240+ (1 , 1 , 8 , 8 ),
241+ (1 , 3 , 4 , 4 ),
242+ ],
204243]
205244LATENT_CNDM_TEST_CASES_DIFF_SHAPES = [
206245 [
@@ -661,7 +700,7 @@ def test_normal_cdf(self):
661700 x = torch .linspace (- 10 , 10 , 20 )
662701 cdf_approx = inferer ._approx_standard_normal_cdf (x )
663702 cdf_true = norm .cdf (x )
664- torch .testing .assert_allclose (cdf_approx , cdf_true , atol = 1e-3 , rtol = 1e-5 )
703+ torch .testing .assert_close (cdf_approx , torch . as_tensor ( cdf_true , dtype = cdf_approx . dtype ) , atol = 1e-3 , rtol = 1e-5 )
665704
666705 @parameterized .expand (CNDM_TEST_CASES )
667706 @skipUnless (has_einops , "Requires einops" )
@@ -742,6 +781,8 @@ def test_prediction_shape(
742781 stage_1 = AutoencoderKL (** autoencoder_params )
743782 if ae_model_type == "VQVAE" :
744783 stage_1 = VQVAE (** autoencoder_params )
784+ if ae_model_type == "SPADEAutoencoderKL" :
785+ stage_1 = SPADEAutoencoderKL (** autoencoder_params )
745786 if dm_model_type == "SPADEDiffusionModelUNet" :
746787 stage_2 = SPADEDiffusionModelUNet (** stage_2_params )
747788 else :
@@ -764,7 +805,7 @@ def test_prediction_shape(
764805 inferer = ControlNetLatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
765806 scheduler .set_timesteps (num_inference_steps = 10 )
766807 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
767- if dm_model_type == "SPADEDiffusionModelUNet" :
808+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
768809 input_shape_seg = list (input_shape )
769810 if "label_nc" in stage_2_params .keys ():
770811 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
@@ -807,14 +848,16 @@ def test_pred_shape(
807848 ):
808849 stage_1 = None
809850
810- if ae_model_type == "AutoencoderKL" :
811- stage_1 = AutoencoderKL (** autoencoder_params )
812- if ae_model_type == "VQVAE" :
813- stage_1 = VQVAE (** autoencoder_params )
814851 if dm_model_type == "SPADEDiffusionModelUNet" :
815852 stage_2 = SPADEDiffusionModelUNet (** stage_2_params )
816853 else :
817854 stage_2 = DiffusionModelUNet (** stage_2_params )
855+ if ae_model_type == "AutoencoderKL" :
856+ stage_1 = AutoencoderKL (** autoencoder_params )
857+ if ae_model_type == "VQVAE" :
858+ stage_1 = VQVAE (** autoencoder_params )
859+ if ae_model_type == "SPADEAutoencoderKL" :
860+ stage_1 = SPADEAutoencoderKL (** autoencoder_params )
818861 controlnet = ControlNet (** controlnet_params )
819862
820863 device = "cuda:0" if torch .cuda .is_available () else "cpu"
@@ -905,19 +948,17 @@ def test_sample_intermediates(
905948 else :
906949 input_shape_seg [1 ] = autoencoder_params ["label_nc" ]
907950 input_seg = torch .randn (input_shape_seg ).to (device )
908- sample = inferer .sample (
951+ sample , intermediates = inferer .sample (
909952 input_noise = noise ,
910953 autoencoder_model = stage_1 ,
911954 diffusion_model = stage_2 ,
912955 scheduler = scheduler ,
913956 seg = input_seg ,
914957 controlnet = controlnet ,
915958 cn_cond = mask ,
959+ save_intermediates = True ,
960+ intermediate_steps = 1 ,
916961 )
917-
918- # TODO: this isn't correct, should the above produce intermediates as well?
919- # This test has always passed so is this branch not being used?
920- intermediates = None
921962 else :
922963 sample , intermediates = inferer .sample (
923964 input_noise = noise ,
@@ -973,7 +1014,7 @@ def test_get_likelihoods(
9731014 inferer = ControlNetLatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
9741015 scheduler .set_timesteps (num_inference_steps = 10 )
9751016
976- if dm_model_type == "SPADEDiffusionModelUNet" :
1017+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
9771018 input_shape_seg = list (input_shape )
9781019 if "label_nc" in stage_2_params .keys ():
9791020 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
@@ -1043,7 +1084,7 @@ def test_resample_likelihoods(
10431084 inferer = ControlNetLatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
10441085 scheduler .set_timesteps (num_inference_steps = 10 )
10451086
1046- if dm_model_type == "SPADEDiffusionModelUNet" :
1087+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
10471088 input_shape_seg = list (input_shape )
10481089 if "label_nc" in stage_2_params .keys ():
10491090 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
@@ -1127,7 +1168,7 @@ def test_prediction_shape_conditioned_concat(
11271168
11281169 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
11291170
1130- if dm_model_type == "SPADEDiffusionModelUNet" :
1171+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
11311172 input_shape_seg = list (input_shape )
11321173 if "label_nc" in stage_2_params .keys ():
11331174 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
@@ -1209,7 +1250,7 @@ def test_sample_shape_conditioned_concat(
12091250 inferer = ControlNetLatentDiffusionInferer (scheduler = scheduler , scale_factor = 1.0 )
12101251 scheduler .set_timesteps (num_inference_steps = 10 )
12111252
1212- if dm_model_type == "SPADEDiffusionModelUNet" :
1253+ if ae_model_type == "SPADEAutoencoderKL" or dm_model_type == "SPADEDiffusionModelUNet" :
12131254 input_shape_seg = list (input_shape )
12141255 if "label_nc" in stage_2_params .keys ():
12151256 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
@@ -1290,7 +1331,7 @@ def test_shape_different_latents(
12901331
12911332 timesteps = torch .randint (0 , scheduler .num_train_timesteps , (input_shape [0 ],), device = input .device ).long ()
12921333
1293- if dm_model_type == "SPADEDiffusionModelUNet" :
1334+ if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL" :
12941335 input_shape_seg = list (input_shape )
12951336 if "label_nc" in stage_2_params .keys ():
12961337 input_shape_seg [1 ] = stage_2_params ["label_nc" ]
0 commit comments