@@ -327,6 +327,96 @@ def test_compatibility_with_monai_generative(self):
 
         net.load_old_state_dict(torch.load(weight_path, weights_only=True), verbose=False)
 
+    @staticmethod
+    def _new_to_old_sd(new_sd: dict, include_proj_attn: bool = True) -> dict:
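+        # Map a current-format state dict onto the legacy (MONAI Generative)
+        # key layout consumed by load_old_state_dict: attention projections
+        # drop their ".attn." prefix, "out_proj" becomes "proj_attn"
+        # (optionally omitted to mimic checkpoints that predate it), and
+        # "postconv" reverts to "conv".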
+        old_sd: dict = {}
+        for k, v in new_sd.items():
+            if ".attn.to_q." in k:
+                old_sd[k.replace(".attn.to_q.", ".to_q.")] = v.clone()
+            elif ".attn.to_k." in k:
+                old_sd[k.replace(".attn.to_k.", ".to_k.")] = v.clone()
+            elif ".attn.to_v." in k:
+                old_sd[k.replace(".attn.to_v.", ".to_v.")] = v.clone()
+            elif ".attn.out_proj." in k:
+                if include_proj_attn:
+                    old_sd[k.replace(".attn.out_proj.", ".proj_attn.")] = v.clone()
+            elif "postconv" in k:
+                old_sd[k.replace("postconv", "conv")] = v.clone()
+            else:
+                old_sd[k] = v.clone()
+        return old_sd
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_load_old_state_dict_proj_attn_copied_to_out_proj(self):
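+        # An old checkpoint stores the attention output projection under
+        # "proj_attn"; loading it should copy those tensors verbatim into
+        # the new "attn.out_proj" parameters.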
+        params = {**self._MIGRATION_PARAMS, "include_fc": True}
+        src = AutoencoderKL(**params).to(device)
+        old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=True)
+
+        # record the tensor values that were stored under proj_attn
+        expected = {k.replace(".proj_attn.", ".attn.out_proj."): v for k, v in old_sd.items() if ".proj_attn." in k}
+        self.assertGreater(len(expected), 0, "No proj_attn keys in old state dict – check model config")
+
+        dst = AutoencoderKL(**params).to(device)
+        dst.load_old_state_dict(old_sd)
+
+        for new_key, expected_val in expected.items():
+            torch.testing.assert_close(
+                dst.state_dict()[new_key],
+                expected_val.to(device),
+                msg=f"Weight mismatch for {new_key}",
+            )
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_load_old_state_dict_missing_proj_attn_initialises_identity(self):
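+        # When the old checkpoint carries no proj_attn entries, the loader
+        # should fall back to an identity weight and zero bias so the
+        # attention output passes through unchanged.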
+        params = {**self._MIGRATION_PARAMS, "include_fc": True}
+        src = AutoencoderKL(**params).to(device)
+        old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
+
+        dst = AutoencoderKL(**params).to(device)
+        dst.load_old_state_dict(old_sd)
+        loaded = dst.state_dict()
+
+        out_proj_weights = [k for k in loaded if "attn.out_proj.weight" in k]
+        out_proj_biases = [k for k in loaded if "attn.out_proj.bias" in k]
+        self.assertGreater(len(out_proj_weights), 0, "No out_proj keys found – check model config")
+
+        for k in out_proj_weights:
+            n = loaded[k].shape[0]
+            torch.testing.assert_close(
+                loaded[k],
+                torch.eye(n, device=device),
+                msg=f"{k} should be an identity matrix",
+            )
+        for k in out_proj_biases:
+            torch.testing.assert_close(
+                loaded[k],
+                torch.zeros_like(loaded[k]),
+                msg=f"{k} should be all-zeros",
+            )
+
+    @skipUnless(has_einops, "Requires einops")
+    def test_load_old_state_dict_proj_attn_discarded_when_no_out_proj(self):
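+        # A model built with include_fc=False has no out_proj layer, so any
+        # proj_attn tensors present in an old checkpoint must be discarded
+        # rather than loaded.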
+        params = {**self._MIGRATION_PARAMS, "include_fc": False}
+        src = AutoencoderKL(**params).to(device)
+        old_sd = self._new_to_old_sd(src.state_dict(), include_proj_attn=False)
+
+        # inject synthetic proj_attn keys (mimic an old checkpoint)
+        attn_blocks = [k.replace(".to_q.weight", "") for k in old_sd if k.endswith(".to_q.weight")]
+        self.assertGreater(len(attn_blocks), 0, "No attention blocks found – check model config")
+        for block in attn_blocks:
+            ch = old_sd[f"{block}.to_q.weight"].shape[0]
+            old_sd[f"{block}.proj_attn.weight"] = torch.randn(ch, ch)
+            old_sd[f"{block}.proj_attn.bias"] = torch.randn(ch)
+
+        dst = AutoencoderKL(**params).to(device)
+        dst.load_old_state_dict(old_sd)
+
+        loaded = dst.state_dict()
+        self.assertFalse(
+            any("out_proj" in k for k in loaded),
+            "out_proj should not exist in a model built with include_fc=False",
+        )
+
 
 if __name__ == "__main__":
     unittest.main()