forked from GuardSkill/ComfyUI-DanceEverywhere
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdance_composite.py
More file actions
162 lines (137 loc) · 6.57 KB
/
dance_composite.py
File metadata and controls
162 lines (137 loc) · 6.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import warnings
import numpy as np
import cv2
import torch
from PIL import Image, ImageFilter
class DanceVideoComposite:
    """Composite a masked dancing-person video onto a static background image.

    Minimal parameter set:
      - (pos_x, pos_y): foot anchor point in relative [0, 1] coordinates —
        the person is horizontally centered on pos_x and their feet (bottom
        edge of the frame) rest on pos_y
      - scale: overall person size as a fraction of the background height
        (0.3 = person height is 30% of the background)
      - height_ratio: height-only tweak, >1 stretches taller, <1 squashes
        shorter; width is unchanged
      - edge_feather: Gaussian feather radius for the mask edge (pixels,
        applied to the resized mask)
      - brightness_blend: how strongly the person's brightness is pulled
        toward the local background brightness (0 = off, 1 = full match)
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "person_video": ("IMAGE",),
                "person_mask": ("MASK",),
                "background": ("IMAGE",),
                "pos_x": ("FLOAT", {
                    "default": 0.5, "min": 0.0, "max": 1.0, "step": 0.005,
                    "display": "slider", "tooltip": "落脚点横坐标(人物水平中心)"
                }),
                "pos_y": ("FLOAT", {
                    "default": 0.85, "min": 0.0, "max": 1.0, "step": 0.005,
                    "display": "slider", "tooltip": "落脚点纵坐标(人物脚底)"
                }),
                "scale": ("FLOAT", {
                    "default": 0.9, "min": 0.01, "max": 2.0, "step": 0.005,
                    "display": "slider",
                    "tooltip": "人物整体大小,相对于背景高度的比例(0.35 = 背景高的 35%)"
                }),
                "height_ratio": ("FLOAT", {
                    "default": 1.0, "min": 0.2, "max": 3.0, "step": 0.01,
                    "display": "slider",
                    "tooltip": ">1 拉高人物,<1 压矮人物(宽度保持不变)"
                }),
                "edge_feather": ("INT", {
                    "default": 6, "min": 0, "max": 80, "step": 1,
                    "tooltip": "mask 边缘羽化半径(像素,相对于原始人物帧)"
                }),
                "brightness_blend": ("FLOAT", {
                    "default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01,
                    "display": "slider",
                    "tooltip": "人物亮度向背景局部亮度靠拢的程度(0=不调整,1=完全匹配)"
                }),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("composited_video",)
    FUNCTION = "composite"
    CATEGORY = "🎬 Video/Composite"

    def composite(
        self,
        person_video: torch.Tensor,  # [N, pH, pW, C]
        person_mask: torch.Tensor,   # [N, pH, pW] or [1, pH, pW]
        background: torch.Tensor,    # [1, bH, bW, C]
        pos_x: float,
        pos_y: float,
        scale: float,
        height_ratio: float,
        edge_feather: int,
        brightness_blend: float,
    ):
        """Composite every person frame onto the background.

        Returns:
            A single-element tuple holding a [N, bH, bW, 3] float tensor.
        """
        N = person_video.shape[0]
        N_mask = person_mask.shape[0]

        # Static background: only the first frame is used, alpha dropped.
        # ascontiguousarray: the channel slice may be a non-contiguous view,
        # which cv2/numpy routines downstream handle best as a dense array.
        bg_np = np.ascontiguousarray(
            background[0].cpu().numpy()[..., :3], dtype=np.float32
        )  # [bH, bW, 3]
        bH, bW = bg_np.shape[:2]

        # Target size: width follows the source aspect ratio; height gets an
        # extra height_ratio tweak on top. The floor of 4 px keeps the resize
        # well-defined even at extreme scales.
        orig_pH, orig_pW = person_video.shape[1], person_video.shape[2]
        base_h = max(4, int(scale * bH))                    # base height
        target_w = max(4, int(base_h * orig_pW / orig_pH))  # keep aspect ratio
        target_h = max(4, int(base_h * height_ratio))       # height tweak

        # Anchor: person horizontally centered on pos_x, feet on pos_y.
        paste_x = int(pos_x * bW) - target_w // 2
        paste_y = int(pos_y * bH) - target_h

        frames_out = []
        for i in range(N):
            frame_np = person_video[i].cpu().numpy()                 # [pH, pW, C]
            # A single-frame mask is reused for every video frame.
            mask_np = person_mask[min(i, N_mask - 1)].cpu().numpy()  # [pH, pW]

            frame_r, mask_r = self._resize(frame_np, mask_np, target_h, target_w)
            mask_f = self._feather(mask_r, edge_feather)
            result = self._composite_frame(
                bg_np, frame_r, mask_f, paste_x, paste_y, brightness_blend
            )
            frames_out.append(torch.from_numpy(result))
        return (torch.stack(frames_out),)  # [N, bH, bW, 3]

    # ------------------------------------------------------------------
    def _resize(self, frame_np, mask_np, target_h, target_w):
        """Lanczos-resize frame and mask to (target_h, target_w).

        Resizes directly in float32 so no 8-bit quantization is introduced.
        Lanczos can overshoot the input range, so both outputs are clipped
        back to [0, 1].
        """
        frame_f = np.ascontiguousarray(frame_np[..., :3], dtype=np.float32)
        frame_r = cv2.resize(frame_f, (target_w, target_h),
                             interpolation=cv2.INTER_LANCZOS4)
        frame_r = np.clip(frame_r, 0.0, 1.0)

        mask_f = np.ascontiguousarray(mask_np, dtype=np.float32)
        mask_r = cv2.resize(mask_f, (target_w, target_h),
                            interpolation=cv2.INTER_LANCZOS4)
        mask_r = np.clip(mask_r, 0.0, 1.0)
        return frame_r, mask_r

    def _feather(self, mask_np, radius):
        """Gaussian-blur the mask edge by `radius` px; no-op for radius <= 0.

        Goes through an 8-bit PIL image, which quantizes the mask slightly
        but uses PIL's well-defined Gaussian radius semantics.
        """
        if radius <= 0:
            return mask_np
        mask_u8 = (mask_np * 255).clip(0, 255).astype(np.uint8)
        blurred = Image.fromarray(mask_u8, mode="L").filter(
            ImageFilter.GaussianBlur(radius=radius)
        )
        return np.array(blurred).astype(np.float32) / 255.0

    def _composite_frame(self, bg_np, person_np, mask_np, paste_x, paste_y,
                         brightness_blend=0.0):
        """Alpha-blend one resized person frame onto a copy of the background.

        Clips the paste rectangle to the background bounds, so partially
        off-canvas placement works; a fully off-canvas person triggers a
        warning and returns the unmodified background copy.
        """
        result = bg_np.copy()
        bH, bW = bg_np.shape[:2]
        tH, tW = person_np.shape[:2]

        # Paste rectangle clipped to the background bounds.
        bg_x0 = max(0, paste_x)
        bg_y0 = max(0, paste_y)
        bg_x1 = min(bW, paste_x + tW)
        bg_y1 = min(bH, paste_y + tH)
        if bg_x0 >= bg_x1 or bg_y0 >= bg_y1:
            warnings.warn(
                f"DanceVideoComposite: 人物超出背景范围 "
                f"(paste_x={paste_x}, paste_y={paste_y}, size={tW}x{tH}),"
                "请调整 pos_x / pos_y。"
            )
            return result

        # Matching source region inside the person frame.
        p_x0 = bg_x0 - paste_x
        p_y0 = bg_y0 - paste_y
        p_x1 = p_x0 + (bg_x1 - bg_x0)
        p_y1 = p_y0 + (bg_y1 - bg_y0)

        person_sl = person_np[p_y0:p_y1, p_x0:p_x1]  # [rH, rW, 3]
        mask_sl = mask_np[p_y0:p_y1, p_x0:p_x1]      # [rH, rW]
        bg_sl = result[bg_y0:bg_y1, bg_x0:bg_x1, :3]

        if brightness_blend > 0.0:
            # Pull the person's mean brightness toward the local background's.
            fg_sel = mask_sl > 0.5
            if fg_sel.any():
                p_mean = float(person_sl[fg_sel].mean())
                if p_mean > 1e-5:  # avoid dividing by a near-black person mean
                    ratio = float(bg_sl.mean()) / p_mean
                    adjusted = np.clip(person_sl * ratio, 0.0, 1.0)
                    person_sl = (person_sl * (1.0 - brightness_blend)
                                 + adjusted * brightness_blend)

        alpha = mask_sl[:, :, np.newaxis]
        result[bg_y0:bg_y1, bg_x0:bg_x1, :3] = person_sl * alpha + bg_sl * (1.0 - alpha)
        return result
# ComfyUI registration tables: node identifier -> implementation class,
# and node identifier -> human-readable display name.
NODE_CLASS_MAPPINGS = {
    "DanceVideoComposite": DanceVideoComposite,
}
NODE_DISPLAY_NAME_MAPPINGS = {
    "DanceVideoComposite": "🕺 Dance Video Composite",
}