Commit 46feb8c

Complete LightX2V's support for Motus with the i2v task. (#992)

Add the Motus feature to LightX2V with the i2v task, where "i" represents the first frame.

Co-authored-by: Shiqiao Gu (谷石桥) <77222802+gushiqiao@users.noreply.github.com>

1 parent d203934, commit 46feb8c

24 files changed

Lines changed: 2241 additions & 0 deletions

configs/motus/motus_i2v.json

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
+{
+    "wan_path": "/path/to/Wan2.2-TI2V-5B",
+    "vlm_path": "/path/to/Qwen3-VL-2B-Instruct",
+    "infer_steps": 10,
+    "text_len": 512,
+    "target_video_length": 9,
+    "target_height": 384,
+    "target_width": 320,
+    "num_channels_latents": 48,
+    "sample_guide_scale": 1.0,
+    "patch_size": [1, 2, 2],
+    "vae_stride": [4, 16, 16],
+    "sample_shift": 5.0,
+    "feature_caching": "NoCaching",
+    "use_image_encoder": false,
+    "enable_cfg": false,
+    "attention_type": "flash_attn2",
+    "self_joint_attn_type": "flash_attn2",
+    "cross_attn_type": "flash_attn2",
+    "global_downsample_rate": 3,
+    "video_action_freq_ratio": 2,
+    "num_video_frames": 8,
+    "video_height": 384,
+    "video_width": 320,
+    "fps": 4,
+    "motus_quantized": false,
+    "motus_quant_scheme": "Default",
+    "load_pretrained_backbones": false,
+    "training_mode": "finetune",
+    "action_state_dim": 14,
+    "action_dim": 14,
+    "action_expert_dim": 1024,
+    "action_expert_ffn_dim_multiplier": 4,
+    "und_expert_hidden_size": 512,
+    "und_expert_ffn_dim_multiplier": 4
+}
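
For orientation, the geometry fields above pin down the latent shapes the rest of the diff works with. A minimal sketch of the arithmetic, assuming the usual Wan-style convention of (frames - 1) / temporal_stride + 1 latent frames (the exact formula lives in the scheduler, not in this excerpt):

# Hypothetical shape arithmetic derived from configs/motus/motus_i2v.json.
target_video_length = 9
target_height, target_width = 384, 320
vae_stride = (4, 16, 16)

latent_t = (target_video_length - 1) // vae_stride[0] + 1  # 3 (assumed Wan convention)
latent_h = target_height // vae_stride[1]                  # 24
latent_w = target_width // vae_stride[2]                   # 20
# With num_channels_latents = 48, video_latents would be [B, 48, 3, 24, 20].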

lightx2v/infer.py

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@
 
 from lightx2v.common.ops import *
 from lightx2v.models.runners.bagel.bagel_runner import BagelRunner  # noqa: F401
+from lightx2v.models.runners.motus.motus_runner import MotusRunner  # noqa: F401
 from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_distill_runner import HunyuanVideo15DistillRunner  # noqa: F401
 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner  # noqa: F401
@@ -82,6 +83,7 @@ def main():
             "bagel",
             "seedvr2",
             "neopp",
+            "motus",
             "lingbot_world_fast",
             "worldmirror",
         ],
@@ -104,6 +106,7 @@
         default="",
         help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. Example: 'path1.jpg,path2.jpg'",
     )
+    parser.add_argument("--state_path", type=str, default="", help="The path to input robot state file for Motus i2v inference.")
     parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
     parser.add_argument(
         "--audio_path",
@@ -191,6 +194,7 @@
     parser.add_argument("--wm_ckpt_path", type=str, default=None, help="(worldmirror/recon) Optional .ckpt/.safetensors (pair with --wm_config_path).")
 
     parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--save_action_path", type=str, default=None, help="The path to save action predictions for Motus.")
     parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
     parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape")
     parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip")
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+from .model import MotusModel
+from .primitives import sinusoidal_embedding_1d
+
+__all__ = [
+    "MotusModel",
+    "sinusoidal_embedding_1d",
+]
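
primitives.py itself is not part of this excerpt. For reference, Wan-style codebases typically implement sinusoidal_embedding_1d roughly as follows; this is an illustrative sketch, not necessarily the Motus version:

import torch

def sinusoidal_embedding_1d(dim: int, position: torch.Tensor) -> torch.Tensor:
    # Illustrative Wan-style 1-D sinusoidal embedding (assumed, may differ
    # from the actual primitives.py): concatenate cos/sin of log-spaced
    # frequencies over half the embedding width.
    half = dim // 2
    freqs = torch.pow(10000, -torch.arange(half, dtype=torch.float64) / half)
    sinusoid = torch.outer(position.to(torch.float64), freqs)
    return torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1).float()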
Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
+import cv2
+import numpy as np
+
+
+def resize_with_padding(frame: np.ndarray, target_size: tuple[int, int]) -> np.ndarray:
+    target_height, target_width = target_size
+    original_height, original_width = frame.shape[:2]
+
+    scale = min(target_height / original_height, target_width / original_width)
+    new_height = int(original_height * scale)
+    new_width = int(original_width * scale)
+
+    resized_frame = cv2.resize(frame, (new_width, new_height))
+    padded_frame = np.zeros((target_height, target_width, frame.shape[2]), dtype=frame.dtype)
+
+    y_offset = (target_height - new_height) // 2
+    x_offset = (target_width - new_width) // 2
+    padded_frame[y_offset : y_offset + new_height, x_offset : x_offset + new_width] = resized_frame
+    return padded_frame
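
A quick usage sketch for the helper above (dummy input; cv2.resize defaults to bilinear interpolation):

import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)   # landscape input
padded = resize_with_padding(frame, (384, 320))   # letterbox to H=384, W=320
assert padded.shape == (384, 320, 3)
# scale = min(384/480, 320/640) = 0.5, so the frame is resized to 240x320
# and centered with 72 rows of zero padding above and below.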
Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from .post_infer import MotusPostInfer
+from .pre_infer import MotusPreInfer
+from .transformer_infer import MotusTransformerInfer
+
+__all__ = ["MotusPreInfer", "MotusTransformerInfer", "MotusPostInfer"]
Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+from dataclasses import dataclass, field
+from typing import Any
+
+import torch
+
+from lightx2v.models.networks.wan.infer.module_io import WanPreInferModuleOutput
+
+
+@dataclass(kw_only=True)
+class MotusPreInferModuleOutput(WanPreInferModuleOutput):
+    state: torch.Tensor
+    first_frame: torch.Tensor
+    instruction: str
+    t5_embeddings: list[torch.Tensor]
+    vlm_inputs: list[dict[str, Any]]
+    image_context: torch.Tensor | None
+    und_tokens: torch.Tensor
+    condition_frame_latent: torch.Tensor
+    adapter_args: dict[str, Any] = field(default_factory=dict)
+    conditional_dict: dict[str, Any] = field(default_factory=dict)
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+import torch
+
+
+class MotusPostInfer:
+    def __init__(self, model, config):
+        self.model = model
+        self.config = config
+        self.scheduler = None
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
+    @torch.no_grad()
+    def infer(self, action_latents: torch.Tensor, pre_infer_out):
+        del pre_infer_out
+        return self.model.denormalize_actions(action_latents.float()).squeeze(0)
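
MotusModel.denormalize_actions is defined elsewhere in this commit. A common pattern for such a method is to invert per-dimension normalization with statistics stored on the model; purely as an illustration (names and formula assumed, not taken from the Motus code):

import torch

def denormalize_actions(actions: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    # Hypothetical inverse of z-score normalization over the trailing
    # action_dim (= 14 in motus_i2v.json) axis.
    return actions * std + mean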
Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+import torch
+
+from lightx2v.models.networks.wan.infer.module_io import GridOutput
+from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+
+from .module_io import MotusPreInferModuleOutput
+
+
+class MotusPreInfer(WanPreInfer):
+    def __init__(self, model, config):
+        super().__init__(config)
+        self.model = model
+        self.scheduler = None
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
+
+    @torch.no_grad()
+    def infer(self, weights, inputs, kv_start=0, kv_end=0):
+        del weights, kv_start, kv_end
+        if self.scheduler is None:
+            raise RuntimeError("MotusPreInfer requires a scheduler before infer().")
+
+        first_frame = inputs["motus_first_frame"]
+        state = inputs["motus_state"]
+        instruction = inputs["motus_instruction"]
+        t5_context = inputs["motus_t5_embeddings"]
+        processed_t5_context = inputs["motus_processed_t5_context"]
+        vlm_inputs = inputs["motus_vlm_inputs"]
+        image_context = inputs["motus_image_context"]
+        und_tokens = inputs["motus_und_tokens"]
+
+        video_latents = self.scheduler.video_latents
+        if video_latents.dim() != 5:
+            raise RuntimeError(f"Expected video latents with shape [B, C, T, H, W], got {tuple(video_latents.shape)}")
+        batch_size = state.shape[0]
+        _, _, latent_t, latent_h, latent_w = video_latents.shape
+        grid_sizes = torch.tensor(
+            [[latent_t, latent_h // self.model.video_backbone.patch_size[1], latent_w // self.model.video_backbone.patch_size[2]]],
+            dtype=torch.long,
+            device=state.device,
+        ).expand(batch_size, -1)
+        grid_output = GridOutput(
+            tensor=grid_sizes,
+            tuple=tuple(int(v) for v in grid_sizes[0].tolist()),
+        )
+
+        if self.cos_sin is None or self.grid_sizes != grid_output.tuple:
+            self.grid_sizes = grid_output.tuple
+            self.cos_sin = self.prepare_cos_sin(grid_output.tuple, self.freqs.clone())
+
+        dummy_embed = torch.empty(0, device=state.device, dtype=processed_t5_context.dtype)
+
+        return MotusPreInferModuleOutput(
+            embed=dummy_embed,
+            grid_sizes=grid_output,
+            x=self.scheduler.video_latents,
+            embed0=dummy_embed,
+            context=processed_t5_context,
+            cos_sin=self.cos_sin,
+            first_frame=first_frame,
+            state=state,
+            instruction=instruction,
+            t5_embeddings=t5_context,
+            vlm_inputs=vlm_inputs,
+            image_context=image_context,
+            und_tokens=und_tokens,
+            condition_frame_latent=self.scheduler.condition_frame_latent,
+            adapter_args={"instruction": instruction},
+        )
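
With the motus_i2v.json values above, the grid computation in infer() is easy to trace by hand (a worked example, assuming the latent shape derived earlier and patch_size [1, 2, 2] on video_backbone):

# video_latents: [B, 48, 3, 24, 20]; patch_size = (1, 2, 2)
latent_t, latent_h, latent_w = 3, 24, 20
grid = (latent_t, latent_h // 2, latent_w // 2)  # (3, 12, 10)
# 3 * 12 * 10 = 360 video tokens per sample; the RoPE cos/sin tables are
# rebuilt only when this tuple changes between calls.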
