Skip to content

Commit c795e35

Browse files
lzws and lzw478614@alibaba-inc.com authored
add wan2.2-fun-A14B inp, control and control-camera (#839)
* update wan2.2-fun * update wan2.2-fun * update wan2.2-fun * add examples * update wan2.2-fun * update wan2.2-fun * Rename Wan2.2-Fun-A14B-Inp.py to Wan2.2-Fun-A14B-InP.py --------- Co-authored-by: lzw478614@alibaba-inc.com <lzw478614@alibaba-inc.com>
1 parent 6a45815 commit c795e35

File tree

7 files changed

+183
-10
lines changed

7 files changed

+183
-10
lines changed

diffsynth/configs/model_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,8 @@
150150
(None, "b61c605c2adbd23124d152ed28e049ae", ["wan_video_dit"], [WanModel], "civitai"),
151151
(None, "1f5ab7703c6fc803fdded85ff040c316", ["wan_video_dit"], [WanModel], "civitai"),
152152
(None, "5b013604280dd715f8457c6ed6d6a626", ["wan_video_dit"], [WanModel], "civitai"),
153+
(None, "2267d489f0ceb9f21836532952852ee5", ["wan_video_dit"], [WanModel], "civitai"),
154+
(None, "47dbeab5e560db3180adf51dc0232fb1", ["wan_video_dit"], [WanModel], "civitai"),
153155
(None, "a61453409b67cd3246cf0c3bebad47ba", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
154156
(None, "7a513e1f257a861512b1afd387a8ecd9", ["wan_video_dit", "wan_video_vace"], [WanModel, VaceWanModel], "civitai"),
155157
(None, "cb104773c6c2cb6df4f9529ad5c60d0b", ["wan_video_dit"], [WanModel], "diffusers"),

diffsynth/models/wan_video_camera_controller.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def process_pose_file(cam_params, width=672, height=384, original_pose_width=128
182182

183183

184184
def generate_camera_coordinates(
185-
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown"],
185+
direction: Literal["Left", "Right", "Up", "Down", "LeftUp", "LeftDown", "RightUp", "RightDown", "In", "Out"],
186186
length: int,
187187
speed: float = 1/54,
188188
origin=(0, 0.532139961, 0.946026558, 0.5, 0.5, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0)
@@ -198,5 +198,9 @@ def generate_camera_coordinates(
198198
coor[13] += speed
199199
if "Down" in direction:
200200
coor[13] -= speed
201+
if "In" in direction:
202+
coor[18] -= speed
203+
if "Out" in direction:
204+
coor[18] += speed
201205
coordinates.append(coor)
202206
return coordinates

diffsynth/models/wan_video_dit.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ def __init__(
294294
):
295295
super().__init__()
296296
self.dim = dim
297+
self.in_dim = in_dim
297298
self.freq_dim = freq_dim
298299
self.has_image_input = has_image_input
299300
self.patch_size = patch_size
@@ -713,6 +714,42 @@ def from_civitai(self, state_dict):
713714
"eps": 1e-6,
714715
"require_clip_embedding": False,
715716
}
717+
elif hash_state_dict_keys(state_dict) == "2267d489f0ceb9f21836532952852ee5":
718+
# Wan2.2-Fun-A14B-Control
719+
config = {
720+
"has_image_input": False,
721+
"patch_size": [1, 2, 2],
722+
"in_dim": 52,
723+
"dim": 5120,
724+
"ffn_dim": 13824,
725+
"freq_dim": 256,
726+
"text_dim": 4096,
727+
"out_dim": 16,
728+
"num_heads": 40,
729+
"num_layers": 40,
730+
"eps": 1e-6,
731+
"has_ref_conv": True,
732+
"require_clip_embedding": False,
733+
}
734+
elif hash_state_dict_keys(state_dict) == "47dbeab5e560db3180adf51dc0232fb1":
735+
# Wan2.2-Fun-A14B-Control-Camera
736+
config = {
737+
"has_image_input": False,
738+
"patch_size": [1, 2, 2],
739+
"in_dim": 36,
740+
"dim": 5120,
741+
"ffn_dim": 13824,
742+
"freq_dim": 256,
743+
"text_dim": 4096,
744+
"out_dim": 16,
745+
"num_heads": 40,
746+
"num_layers": 40,
747+
"eps": 1e-6,
748+
"has_ref_conv": False,
749+
"add_control_adapter": True,
750+
"in_dim_control_adapter": 24,
751+
"require_clip_embedding": False,
752+
}
716753
else:
717754
config = {}
718755
return state_dict, config

diffsynth/pipelines/wan_video_new.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -663,22 +663,23 @@ def process(self, pipe: WanVideoPipeline, input_image, latents, height, width, t
663663
class WanVideoUnit_FunControl(PipelineUnit):
664664
def __init__(self):
665665
super().__init__(
666-
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y"),
666+
input_params=("control_video", "num_frames", "height", "width", "tiled", "tile_size", "tile_stride", "clip_feature", "y", "latents"),
667667
onload_model_names=("vae",)
668668
)
669669

670-
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y):
670+
def process(self, pipe: WanVideoPipeline, control_video, num_frames, height, width, tiled, tile_size, tile_stride, clip_feature, y, latents):
671671
if control_video is None:
672672
return {}
673673
pipe.load_models_to_device(self.onload_model_names)
674674
control_video = pipe.preprocess_video(control_video)
675675
control_latents = pipe.vae.encode(control_video, device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(dtype=pipe.torch_dtype, device=pipe.device)
676676
control_latents = control_latents.to(dtype=pipe.torch_dtype, device=pipe.device)
677+
y_dim = pipe.dit.in_dim-control_latents.shape[1]-latents.shape[1]
677678
if clip_feature is None or y is None:
678679
clip_feature = torch.zeros((1, 257, 1280), dtype=pipe.torch_dtype, device=pipe.device)
679-
y = torch.zeros((1, 16, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
680+
y = torch.zeros((1, y_dim, (num_frames - 1) // 4 + 1, height//8, width//8), dtype=pipe.torch_dtype, device=pipe.device)
680681
else:
681-
y = y[:, -16:]
682+
y = y[:, -y_dim:]
682683
y = torch.concat([control_latents, y], dim=1)
683684
return {"clip_feature": clip_feature, "y": y}
684685

@@ -698,6 +699,8 @@ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
698699
reference_image = reference_image.resize((width, height))
699700
reference_latents = pipe.preprocess_video([reference_image])
700701
reference_latents = pipe.vae.encode(reference_latents, device=pipe.device)
702+
if pipe.image_encoder is None:
703+
return {"reference_latents": reference_latents}
701704
clip_feature = pipe.preprocess_image(reference_image)
702705
clip_feature = pipe.image_encoder.encode_image([clip_feature])
703706
return {"reference_latents": reference_latents, "clip_feature": clip_feature}
@@ -707,13 +710,14 @@ def process(self, pipe: WanVideoPipeline, reference_image, height, width):
707710
class WanVideoUnit_FunCameraControl(PipelineUnit):
708711
def __init__(self):
709712
super().__init__(
710-
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image"),
713+
input_params=("height", "width", "num_frames", "camera_control_direction", "camera_control_speed", "camera_control_origin", "latents", "input_image", "tiled", "tile_size", "tile_stride"),
711714
onload_model_names=("vae",)
712715
)
713716

714-
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image):
717+
def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_control_direction, camera_control_speed, camera_control_origin, latents, input_image, tiled, tile_size, tile_stride):
715718
if camera_control_direction is None:
716719
return {}
720+
pipe.load_models_to_device(self.onload_model_names)
717721
camera_control_plucker_embedding = pipe.dit.control_adapter.process_camera_coordinates(
718722
camera_control_direction, num_frames, height, width, camera_control_speed, camera_control_origin)
719723

@@ -728,14 +732,27 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, camera_cont
728732
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, 4, c, h, w).transpose(2, 3)
729733
control_camera_latents = control_camera_latents.contiguous().view(b, f // 4, c * 4, h, w).transpose(1, 2)
730734
control_camera_latents_input = control_camera_latents.to(device=pipe.device, dtype=pipe.torch_dtype)
731-
735+
732736
input_image = input_image.resize((width, height))
733737
input_latents = pipe.preprocess_video([input_image])
734-
pipe.load_models_to_device(self.onload_model_names)
735738
input_latents = pipe.vae.encode(input_latents, device=pipe.device)
736739
y = torch.zeros_like(latents).to(pipe.device)
737740
y[:, :, :1] = input_latents
738741
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
742+
743+
if y.shape[1] != pipe.dit.in_dim - latents.shape[1]:
744+
image = pipe.preprocess_image(input_image.resize((width, height))).to(pipe.device)
745+
vae_input = torch.concat([image.transpose(0, 1), torch.zeros(3, num_frames-1, height, width).to(image.device)], dim=1)
746+
y = pipe.vae.encode([vae_input.to(dtype=pipe.torch_dtype, device=pipe.device)], device=pipe.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)[0]
747+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
748+
msk = torch.ones(1, num_frames, height//8, width//8, device=pipe.device)
749+
msk[:, 1:] = 0
750+
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
751+
msk = msk.view(1, msk.shape[1] // 4, 4, height//8, width//8)
752+
msk = msk.transpose(1, 2)[0]
753+
y = torch.cat([msk,y])
754+
y = y.unsqueeze(0)
755+
y = y.to(dtype=pipe.torch_dtype, device=pipe.device)
739756
return {"control_camera_latents_input": control_camera_latents_input, "y": y}
740757

741758

@@ -1048,7 +1065,7 @@ def model_fn_wan_video(
10481065
if clip_feature is not None and dit.require_clip_embedding:
10491066
clip_embdding = dit.img_emb(clip_feature)
10501067
context = torch.cat([clip_embdding, context], dim=1)
1051-
1068+
10521069
# Add camera control
10531070
x, (f, h, w) = dit.patchify(x, control_camera_latents_input)
10541071

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
"""Example: camera-controlled video generation with Wan2.2-Fun-A14B-Control-Camera.

Loads the two-expert (high/low noise) DiT, the T5 text encoder, and the VAE,
then renders the same prompt twice with different virtual-camera motions
("Left" and "Up") starting from a single input image.
"""
import torch
from diffsynth import save_video, VideoData  # PEP 8: space after comma
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

# Build the pipeline; weights are offloaded to CPU until needed.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control-Camera", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Fetch the example input image from the ModelScope dataset.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/input_image.jpg",  # fixed: was an f-string with no placeholders
)
input_image = Image.open("data/examples/wan/input_image.jpg")

# Pan the camera to the left.
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Left", camera_control_speed=0.01,
)
save_video(video, "video_left.mp4", fps=15, quality=5)

# Same prompt and seed, but pan the camera upward.
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    seed=0, tiled=True,
    input_image=input_image,
    camera_control_direction="Up", camera_control_speed=0.01,
)
save_video(video, "video_up.mp4", fps=15, quality=5)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
"""Example: pose/control-video driven generation with Wan2.2-Fun-A14B-Control.

Loads the two-expert (high/low noise) DiT, the T5 text encoder, and the VAE,
then generates a video conditioned on a control video plus a reference image.
"""
import torch
from diffsynth import save_video, VideoData  # PEP 8: space after comma
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

# Build the pipeline; weights are offloaded to CPU until needed.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-Control", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Fetch the example control video and reference image.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern=["data/examples/wan/control_video.mp4", "data/examples/wan/reference_image_girl.png"]
)

# Control video drives motion; the reference image fixes the subject's appearance.
# The reference image is resized to match the requested output resolution.
control_video = VideoData("data/examples/wan/control_video.mp4", height=832, width=576)
reference_image = Image.open("data/examples/wan/reference_image_girl.png").resize((576, 832))
video = pipe(
    prompt="扁平风格动漫,一位长发少女优雅起舞。她五官精致,大眼睛明亮有神,黑色长发柔顺光泽。身穿淡蓝色T恤和深蓝色牛仔短裤。背景是粉色。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    control_video=control_video, reference_image=reference_image,
    height=832, width=576, num_frames=49,
    seed=1, tiled=True
)
save_video(video, "video.mp4", fps=15, quality=5)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
"""Example: image-to-video generation with Wan2.2-Fun-A14B-InP.

Loads the two-expert (high/low noise) DiT, the T5 text encoder, and the VAE,
then animates a single input image into a video. Optionally a last frame can
be supplied via `end_image` for first-and-last-frame interpolation.
"""
import torch
from diffsynth import save_video
from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig
from PIL import Image
from modelscope import dataset_snapshot_download

# Build the pipeline; weights are offloaded to CPU until needed.
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cuda",
    model_configs=[
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="high_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="low_noise_model/diffusion_pytorch_model*.safetensors", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="models_t5_umt5-xxl-enc-bf16.pth", offload_device="cpu"),
        ModelConfig(model_id="PAI/Wan2.2-Fun-A14B-InP", origin_file_pattern="Wan2.1_VAE.pth", offload_device="cpu"),
    ],
)
pipe.enable_vram_management()

# Fetch the example input image from the ModelScope dataset.
dataset_snapshot_download(
    dataset_id="DiffSynth-Studio/examples_in_diffsynth",
    local_dir="./",
    allow_file_pattern="data/examples/wan/input_image.jpg",  # fixed: was an f-string with no placeholders
)
image = Image.open("data/examples/wan/input_image.jpg")

# First (and optionally last) frame to video
video = pipe(
    prompt="一艘小船正勇敢地乘风破浪前行。蔚蓝的大海波涛汹涌,白色的浪花拍打着船身,但小船毫不畏惧,坚定地驶向远方。阳光洒在水面上,闪烁着金色的光芒,为这壮丽的场景增添了一抹温暖。镜头拉近,可以看到船上的旗帜迎风飘扬,象征着不屈的精神与冒险的勇气。这段画面充满力量,激励人心,展现了面对挑战时的无畏与执着。",
    negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
    input_image=image,
    seed=0, tiled=True,
    # You can input `end_image=xxx` to control the last frame of the video.
    # The model will automatically generate the dynamic content between `input_image` and `end_image`.
)
save_video(video, "video.mp4", fps=15, quality=5)

0 commit comments

Comments
 (0)