169169
170170
171171class TestAutoEncoderKL (unittest .TestCase ):
172+ _MIGRATION_PARAMS = {
173+ "spatial_dims" : 2 ,
174+ "in_channels" : 1 ,
175+ "out_channels" : 1 ,
176+ "channels" : (4 , 4 , 4 ),
177+ "latent_channels" : 4 ,
178+ "attention_levels" : (False , False , False ),
179+ "num_res_blocks" : 1 ,
180+ "norm_num_groups" : 4 ,
181+ }
182+
172183 @parameterized .expand (CASES )
173184 def test_shape (self , input_param , input_shape , expected_shape , expected_latent_shape ):
174185 net = AutoencoderKL (** input_param ).to (device )
@@ -329,6 +340,15 @@ def test_compatibility_with_monai_generative(self):
329340
330341 @staticmethod
331342 def _new_to_old_sd (new_sd : dict , include_proj_attn : bool = True ) -> dict :
343+ """Convert new-style state dict keys to legacy naming conventions.
344+
345+ Args:
346+ new_sd: State dict with current key naming.
347+ include_proj_attn: If True, map `.attn.out_proj.` to `.proj_attn.`.
348+
349+ Returns:
350+ State dict with legacy key names.
351+ """
332352 old_sd : dict = {}
333353 for k , v in new_sd .items ():
334354 if ".attn.to_q." in k :
@@ -354,7 +374,7 @@ def test_load_old_state_dict_proj_attn_copied_to_out_proj(self):
354374
355375 # record the tensor values that were stored under proj_attn
356376 expected = {k .replace (".proj_attn." , ".attn.out_proj." ): v for k , v in old_sd .items () if ".proj_attn." in k }
357- self .assertGreater (len (expected ), 0 , "No proj_attn keys in old state dict – check model config" )
377+ self .assertGreater (len (expected ), 0 , "No proj_attn keys in old state dict - check model config" )
358378
359379 dst = AutoencoderKL (** params ).to (device )
360380 dst .load_old_state_dict (old_sd )
0 commit comments