Skip to content

Commit d8d163c

Browse files
authored
Merge pull request #573 from Kosinkadink/fix/issues-569-572-cast-bias-weight
fix: update cast_bias_weight usage and fix motion model dtype loading
2 parents 90fb133 + 8c277d9 commit d8d163c

4 files changed

Lines changed: 14 additions & 10 deletions

File tree

animatediff/adapter_hellomeme.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,9 @@ def load_hmreferenceadapter(model_name: str):
7070
else:
7171
ops = comfy.ops.manual_cast
7272
hmref = HMReferenceAdapter(ops=ops)
73-
hmref.to(comfy.model_management.unet_dtype())
7473
hmref.to(comfy.model_management.unet_offload_device())
7574
load_result = hmref.load_state_dict(state_dict, strict=True)
75+
hmref.to(comfy.model_management.unet_dtype())
7676
hmref_model = create_HMModelPatcher(model=hmref, load_device=comfy.model_management.get_torch_device(),
7777
offload_device=comfy.model_management.unet_offload_device())
7878
return hmref_model

animatediff/model_injection.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -841,9 +841,9 @@ def load_motion_module_gen1(model_name: str, model: ModelPatcher, motion_lora: M
841841
mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
842842
# initialize AnimateDiffModelWrapper
843843
ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
844-
ad_wrapper.to(model.model_dtype())
845844
ad_wrapper.to(model.offload_device)
846845
load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False)
846+
ad_wrapper.to(model.model_dtype())
847847
verify_load_result(load_result=load_result, mm_info=mm_info)
848848
# wrap motion_module into a ModelPatcher, to allow motion lora patches
849849
motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=model.load_device, offload_device=model.offload_device)
@@ -865,9 +865,9 @@ def load_motion_module_gen2(model_name: str, motion_model_settings: AnimateDiffS
865865
mm_state_dict = apply_mm_settings(model_dict=mm_state_dict, mm_settings=motion_model_settings)
866866
# initialize AnimateDiffModelWrapper
867867
ad_wrapper = AnimateDiffModel(mm_state_dict=mm_state_dict, mm_info=mm_info)
868-
ad_wrapper.to(comfy.model_management.unet_dtype())
869868
ad_wrapper.to(comfy.model_management.unet_offload_device())
870869
load_result = ad_wrapper.load_state_dict(mm_state_dict, strict=False)
870+
ad_wrapper.to(comfy.model_management.unet_dtype())
871871
verify_load_result(load_result=load_result, mm_info=mm_info)
872872
# wrap motion_module into a ModelPatcher, to allow motion lora patches
873873
motion_model = create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
@@ -907,34 +907,34 @@ def verify_load_result(load_result: IncompatibleKeys, mm_info: AnimateDiffInfo):
907907

908908
def create_fresh_motion_module(motion_model: MotionModelPatcher) -> MotionModelPatcher:
909909
ad_wrapper = AnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info)
910-
ad_wrapper.to(comfy.model_management.unet_dtype())
911910
ad_wrapper.to(comfy.model_management.unet_offload_device())
912911
ad_wrapper.load_state_dict(motion_model.model.state_dict())
912+
ad_wrapper.to(comfy.model_management.unet_dtype())
913913
return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
914914
offload_device=comfy.model_management.unet_offload_device())
915915

916916

917917
def create_fresh_encoder_only_model(motion_model: MotionModelPatcher) -> MotionModelPatcher:
918918
ad_wrapper = EncoderOnlyAnimateDiffModel(mm_state_dict=motion_model.model.state_dict(), mm_info=motion_model.model.mm_info)
919-
ad_wrapper.to(comfy.model_management.unet_dtype())
920919
ad_wrapper.to(comfy.model_management.unet_offload_device())
921920
ad_wrapper.load_state_dict(motion_model.model.state_dict(), strict=False)
921+
ad_wrapper.to(comfy.model_management.unet_dtype())
922922
return create_MotionModelPatcher(model=ad_wrapper, load_device=comfy.model_management.get_torch_device(),
923923
offload_device=comfy.model_management.unet_offload_device())
924924

925925

926926
def inject_img_encoder_into_model(motion_model: MotionModelPatcher, w_encoder: MotionModelPatcher):
927927
motion_model.model.init_img_encoder()
928-
motion_model.model.img_encoder.to(comfy.model_management.unet_dtype())
929928
motion_model.model.img_encoder.to(comfy.model_management.unet_offload_device())
930929
motion_model.model.img_encoder.load_state_dict(w_encoder.model.img_encoder.state_dict())
930+
motion_model.model.img_encoder.to(comfy.model_management.unet_dtype())
931931

932932

933933
def inject_pia_conv_in_into_model(motion_model: MotionModelPatcher, w_pia: MotionModelPatcher):
934934
motion_model.model.init_conv_in(w_pia.model.state_dict())
935-
motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
936935
motion_model.model.conv_in.to(comfy.model_management.unet_offload_device())
937936
motion_model.model.conv_in.load_state_dict(w_pia.model.conv_in.state_dict())
937+
motion_model.model.conv_in.to(comfy.model_management.unet_dtype())
938938
motion_model.model.mm_info.mm_format = AnimateDiffFormat.PIA
939939

940940

@@ -956,9 +956,9 @@ def inject_camera_encoder_into_model(motion_model: MotionModelPatcher, camera_ct
956956
# initialize CameraPoseEncoder on motion model, and load keys
957957
camera_encoder = CameraPoseEncoder(channels=motion_model.model.layer_channels, nums_rb=2, ops=motion_model.model.ops).to(
958958
device=comfy.model_management.unet_offload_device(),
959-
dtype=comfy.model_management.unet_dtype()
960959
)
961960
camera_encoder.load_state_dict(camera_state_dict)
961+
camera_encoder.to(dtype=comfy.model_management.unet_dtype())
962962
camera_encoder.temporal_pe_max_len = get_position_encoding_max_len(camera_state_dict, mm_name=camera_ctrl_name, mm_format=AnimateDiffFormat.ANIMATEDIFF)
963963
motion_model.model.set_camera_encoder(camera_encoder=camera_encoder)
964964
# initialize qkv_merge on specific attention blocks, and load keys

animatediff/sampling.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,13 @@ def groupnorm_mm_forward(self, input: Tensor) -> Tensor:
203203

204204
input = rearrange(input, "(b f) c h w -> b c f h w", b=batched_conds)
205205
if manual_cast:
206-
weight, bias = comfy.ops.cast_bias_weight(self, input)
206+
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
207207
else:
208208
weight, bias = self.weight, self.bias
209+
offload_stream = None
209210
input = group_norm(input, self.num_groups, weight, bias, self.eps)
211+
if offload_stream is not None:
212+
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
210213
input = rearrange(input, "b c f h w -> (b f) c h w", b=batched_conds)
211214
return input
212215
return groupnorm_mm_forward

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-animatediff-evolved"
33
description = "Improved AnimateDiff integration for ComfyUI."
4-
version = "1.5.6"
4+
version = "1.5.7"
55
license = { file = "LICENSE" }
66
dependencies = []
77

@@ -13,3 +13,4 @@ Repository = "https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved"
1313
PublisherId = "kosinkadink"
1414
DisplayName = "ComfyUI-AnimateDiff-Evolved"
1515
Icon = ""
16+
requires-comfyui = ">=0.3.68"

0 commit comments

Comments
 (0)