Skip to content

Commit 9861558

Browse files
committed
Add comfyui auto_target_shape for animate model infer
1 parent b7b1165 commit 9861558

6 files changed

Lines changed: 58 additions & 19 deletions

File tree

configs/platforms/mlu/wan_animate.json

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
{
22
"infer_steps": 4,
3-
"target_video_length": 77,
3+
"target_video_length": 81,
44
"text_len": 512,
5-
"target_height": 720,
6-
"target_width": 1280,
5+
"auto_target_shape": false,
6+
"target_height": 1280,
7+
"target_width": 720,
78
"self_attn_1_type": "mlu_flash_attn",
89
"cross_attn_1_type": "mlu_flash_attn",
910
"cross_attn_2_type": "mlu_flash_attn",
@@ -18,9 +19,7 @@
1819
"rms_norm_type": "mlu_rms_norm",
1920
"refert_num": 1,
2021
"replace_flag": false,
21-
"fps": 30,
22-
"denoising_step_list": [1000, 750, 500, 250],
23-
"scheduler_type": "WanStepDistillScheduler",
22+
"fps": 24,
2423
"lora_configs": [
2524
{
2625
"path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
{
22
"infer_steps": 4,
3-
"target_video_length": 77,
3+
"target_video_length": 81,
44
"text_len": 512,
5-
"target_height": 720,
6-
"target_width": 1280,
5+
"auto_target_shape": false,
6+
"target_height": 1280,
7+
"target_width": 720,
78
"self_attn_1_type": "mlu_flash_attn",
89
"cross_attn_1_type": "mlu_flash_attn",
910
"cross_attn_2_type": "mlu_flash_attn",
@@ -18,17 +19,15 @@
1819
"rms_norm_type": "mlu_rms_norm",
1920
"refert_num": 1,
2021
"replace_flag": false,
21-
"fps": 30,
22-
"denoising_step_list": [1000, 750, 500, 250],
23-
"scheduler_type": "WanStepDistillScheduler",
22+
"fps": 24,
2423
"lora_configs": [
2524
{
2625
"path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",
2726
"strength": 1.0
2827
}
2928
],
3029
"parallel": {
31-
"seq_p_size": 8,
30+
"seq_p_size": 2,
3231
"seq_p_attn_type": "ulysses"
3332
}
3433
}

configs/wan22/wan_animate_lora.json

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
"refert_num": 1,
1616
"replace_flag": false,
1717
"fps": 30,
18-
"denoising_step_list": [1000, 750, 500, 250],
19-
"scheduler_type": "WanStepDistillScheduler",
2018
"lora_configs": [
2119
{
2220
"path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",

configs/wan22/wan_animate_lora_dist.json

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
"refert_num": 1,
1616
"replace_flag": false,
1717
"fps": 30,
18-
"denoising_step_list": [1000, 750, 500, 250],
19-
"scheduler_type": "WanStepDistillScheduler",
2018
"lora_configs": [
2119
{
2220
"path": "lightx2v/Wan2.1-Distill-Loras/wan2.1_i2v_lora_rank64_lightx2v_4step.safetensors",

lightx2v/models/runners/wan/wan_animate_runner.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,40 @@ def padding_resize(
106106

107107
return img_pad
108108

109+
def use_auto_target_shape(self):
110+
return self.config.get("auto_target_shape", True)
111+
112+
def get_comfy_target_shape(self):
113+
height = (int(self.config["target_height"]) // 16) * 16
114+
width = (int(self.config["target_width"]) // 16) * 16
115+
if height <= 0 or width <= 0:
116+
raise ValueError(f"Invalid WanAnimate target shape: height={height}, width={width}")
117+
return height, width
118+
119+
def center_crop_to_aspect(self, img, height, width):
120+
ori_height, ori_width = img.shape[:2]
121+
target_aspect = width / height
122+
ori_aspect = ori_width / ori_height
123+
if ori_aspect > target_aspect:
124+
crop_width = max(1, round(ori_height * target_aspect))
125+
x0 = max(0, (ori_width - crop_width) // 2)
126+
img = img[:, x0 : x0 + crop_width]
127+
elif ori_aspect < target_aspect:
128+
crop_height = max(1, round(ori_width / target_aspect))
129+
y0 = max(0, (ori_height - crop_height) // 2)
130+
img = img[y0 : y0 + crop_height]
131+
return img
132+
133+
def comfy_resize(self, img, height, width, interpolation=cv2.INTER_LANCZOS4, crop=None):
134+
if crop == "center":
135+
img = self.center_crop_to_aspect(img, height=height, width=width)
136+
if img.shape[0] == height and img.shape[1] == width:
137+
return img
138+
return cv2.resize(img, (width, height), interpolation=interpolation)
139+
140+
def comfy_resize_frames(self, frames, height, width, interpolation=cv2.INTER_LANCZOS4, crop=None):
141+
return np.stack([self.comfy_resize(frame, height, width, interpolation=interpolation, crop=crop) for frame in frames])
142+
109143
def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
110144
pose_video_reader = VideoReader(src_pose_path)
111145
pose_len = len(pose_video_reader)
@@ -118,7 +152,14 @@ def prepare_source(self, src_pose_path, src_face_path, src_ref_path):
118152
face_images = face_video_reader.get_batch(face_idxs).asnumpy()
119153
height, width = cond_images[0].shape[:2]
120154
refer_images = cv2.imread(src_ref_path)[..., ::-1]
121-
refer_images = self.padding_resize(refer_images, height=height, width=width)
155+
if self.use_auto_target_shape():
156+
refer_images = self.padding_resize(refer_images, height=height, width=width)
157+
else:
158+
target_height, target_width = self.get_comfy_target_shape()
159+
logger.info(f"WanAnimate uses config target shape: height={target_height}, width={target_width}")
160+
cond_images = self.comfy_resize_frames(cond_images, target_height, target_width)
161+
refer_images = self.comfy_resize(refer_images, target_height, target_width)
162+
face_images = self.comfy_resize_frames(face_images, 512, 512, crop="center")
122163
return cond_images, face_images, refer_images
123164

124165
def prepare_source_for_replace(self, src_bg_path, src_mask_path):
@@ -132,6 +173,10 @@ def prepare_source_for_replace(self, src_bg_path, src_mask_path):
132173
mask_idxs = list(range(mask_len))
133174
mask_images = mask_video_reader.get_batch(mask_idxs).asnumpy()
134175
mask_images = mask_images[:, :, :, 0] / 255
176+
if not self.use_auto_target_shape():
177+
target_height, target_width = self.get_comfy_target_shape()
178+
bg_images = self.comfy_resize_frames(bg_images, target_height, target_width)
179+
mask_images = self.comfy_resize_frames(mask_images, target_height, target_width, interpolation=cv2.INTER_NEAREST)
135180
return bg_images, mask_images
136181

137182
@ProfilingContext4DebugL2("Run Image Encoders")

scripts/platforms/mlu/run_wan22_animate_dist.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ python ${lightx2v_path}/tools/preprocess/preprocess_data.py \
2525
--resolution_area 1280 720 \
2626
--retarget_flag \
2727

28-
torchrun --nproc_per_node=8 -m lightx2v.infer \
28+
torchrun --nproc_per_node=2 -m lightx2v.infer \
2929
--model_cls wan2.2_animate \
3030
--task animate \
3131
--model_path $model_path \

0 commit comments

Comments
 (0)