@@ -841,9 +841,9 @@ def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: M
841841 mm_state_dict = apply_mm_settings (model_dict = mm_state_dict , mm_settings = motion_model_settings )
842842 # initialize AnimateDiffModelWrapper
843843 ad_wrapper = AnimateDiffModel (mm_state_dict = mm_state_dict , mm_info = mm_info )
844- ad_wrapper .to (model .model_dtype ())
845844 ad_wrapper .to (model .offload_device )
846845 load_result = ad_wrapper .load_state_dict (mm_state_dict , strict = False )
846+ ad_wrapper .to (model .model_dtype ())
847847 verify_load_result (load_result = load_result , mm_info = mm_info )
848848 # wrap motion_module into a ModelPatcher, to allow motion lora patches
849849 motion_model = create_MotionModelPatcher (model = ad_wrapper , load_device = model .load_device , offload_device = model .offload_device )
@@ -865,9 +865,9 @@ def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffS
865865 mm_state_dict = apply_mm_settings (model_dict = mm_state_dict , mm_settings = motion_model_settings )
866866 # initialize AnimateDiffModelWrapper
867867 ad_wrapper = AnimateDiffModel (mm_state_dict = mm_state_dict , mm_info = mm_info )
868- ad_wrapper .to (comfy .model_management .unet_dtype ())
869868 ad_wrapper .to (comfy .model_management .unet_offload_device ())
870869 load_result = ad_wrapper .load_state_dict (mm_state_dict , strict = False )
870+ ad_wrapper .to (comfy .model_management .unet_dtype ())
871871 verify_load_result (load_result = load_result , mm_info = mm_info )
872872 # wrap motion_module into a ModelPatcher, to allow motion lora patches
873873 motion_model = create_MotionModelPatcher (model = ad_wrapper , load_device = comfy .model_management .get_torch_device (),
@@ -907,34 +907,34 @@ def verify_load_result(load_result: IncompatibleKeys, mm_info: AnimateDiffInfo):
907907
908908def create_fresh_motion_module (motion_model : MotionModelPatcher ) -> MotionModelPatcher :
909909 ad_wrapper = AnimateDiffModel (mm_state_dict = motion_model .model .state_dict (), mm_info = motion_model .model .mm_info )
910- ad_wrapper .to (comfy .model_management .unet_dtype ())
911910 ad_wrapper .to (comfy .model_management .unet_offload_device ())
912911 ad_wrapper .load_state_dict (motion_model .model .state_dict ())
912+ ad_wrapper .to (comfy .model_management .unet_dtype ())
913913 return create_MotionModelPatcher (model = ad_wrapper , load_device = comfy .model_management .get_torch_device (),
914914 offload_device = comfy .model_management .unet_offload_device ())
915915
916916
917917def create_fresh_encoder_only_model (motion_model : MotionModelPatcher ) -> MotionModelPatcher :
918918 ad_wrapper = EncoderOnlyAnimateDiffModel (mm_state_dict = motion_model .model .state_dict (), mm_info = motion_model .model .mm_info )
919- ad_wrapper .to (comfy .model_management .unet_dtype ())
920919 ad_wrapper .to (comfy .model_management .unet_offload_device ())
921920 ad_wrapper .load_state_dict (motion_model .model .state_dict (), strict = False )
921+ ad_wrapper .to (comfy .model_management .unet_dtype ())
922922 return create_MotionModelPatcher (model = ad_wrapper , load_device = comfy .model_management .get_torch_device (),
923923 offload_device = comfy .model_management .unet_offload_device ())
924924
925925
926926def inject_img_encoder_into_model (motion_model : MotionModelPatcher , w_encoder : MotionModelPatcher ):
927927 motion_model .model .init_img_encoder ()
928- motion_model .model .img_encoder .to (comfy .model_management .unet_dtype ())
929928 motion_model .model .img_encoder .to (comfy .model_management .unet_offload_device ())
930929 motion_model .model .img_encoder .load_state_dict (w_encoder .model .img_encoder .state_dict ())
930+ motion_model .model .img_encoder .to (comfy .model_management .unet_dtype ())
931931
932932
933933def inject_pia_conv_in_into_model (motion_model : MotionModelPatcher , w_pia : MotionModelPatcher ):
934934 motion_model .model .init_conv_in (w_pia .model .state_dict ())
935- motion_model .model .conv_in .to (comfy .model_management .unet_dtype ())
936935 motion_model .model .conv_in .to (comfy .model_management .unet_offload_device ())
937936 motion_model .model .conv_in .load_state_dict (w_pia .model .conv_in .state_dict ())
937+ motion_model .model .conv_in .to (comfy .model_management .unet_dtype ())
938938 motion_model .model .mm_info .mm_format = AnimateDiffFormat .PIA
939939
940940
@@ -956,9 +956,9 @@ def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ct
956956 # initialize CameraPoseEncoder on motion model, and load keys
957957 camera_encoder = CameraPoseEncoder (channels = motion_model .model .layer_channels , nums_rb = 2 , ops = motion_model .model .ops ).to (
958958 device = comfy .model_management .unet_offload_device (),
959- dtype = comfy .model_management .unet_dtype ()
960959 )
961960 camera_encoder .load_state_dict (camera_state_dict )
961+ camera_encoder .to (dtype = comfy .model_management .unet_dtype ())
962962 camera_encoder .temporal_pe_max_len = get_position_encoding_max_len (camera_state_dict , mm_name = camera_ctrl_name , mm_format = AnimateDiffFormat .ANIMATEDIFF )
963963 motion_model .model .set_camera_encoder (camera_encoder = camera_encoder )
964964 # initialize qkv_merge on specific attention blocks, and load keys
0 commit comments