diff --git a/configs/motus/motus_i2v.json b/configs/motus/motus_i2v.json new file mode 100644 index 000000000..68cbb5d0b --- /dev/null +++ b/configs/motus/motus_i2v.json @@ -0,0 +1,30 @@ +{ + "checkpoint_path": "/path/to/MotusModel", + "wan_path": "/path/to/Wan2.2-TI2V-5B", + "vlm_path": "/path/to/Qwen3-VL-2B-Instruct", + "infer_steps": 10, + "num_inference_steps": 10, + "target_video_length": 9, + "target_height": 384, + "target_width": 320, + "attention_type": "flash_attn2", + "self_attn_1_type": "flash_attn2", + "self_attn_2_type": "flash_attn2", + "cross_attn_1_type": "flash_attn2", + "global_downsample_rate": 3, + "video_action_freq_ratio": 2, + "num_video_frames": 8, + "video_height": 384, + "video_width": 320, + "fps": 4, + "motus_quantized": false, + "motus_quant_scheme": "Default", + "load_pretrained_backbones": false, + "training_mode": "finetune", + "action_state_dim": 14, + "action_dim": 14, + "action_expert_dim": 1024, + "action_expert_ffn_dim_multiplier": 4, + "und_expert_hidden_size": 512, + "und_expert_ffn_dim_multiplier": 4 +} diff --git a/lightx2v/infer.py b/lightx2v/infer.py index 2ebd770b4..03239e6bb 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -7,6 +7,7 @@ from lightx2v.common.ops import * from lightx2v.models.runners.bagel.bagel_runner import BagelRunner # noqa: F401 +from lightx2v.models.runners.motus.motus_runner import MotusRunner # noqa: F401 try: from lightx2v.models.runners.flux2_klein.flux2_klein_runner import Flux2KleinRunner # noqa: F401 @@ -82,6 +83,7 @@ def main(): "bagel", "seedvr2", "neopp", + "motus", ], default="wan2.1", ) @@ -102,6 +104,7 @@ def main(): default="", help="The path to input image file(s) for image-to-video (i2v) or image-to-audio-video (i2av) task. Multiple paths should be comma-separated. 
Example: 'path1.jpg,path2.jpg'", ) + parser.add_argument("--state_path", type=str, default="", help="The path to input robot state file for Motus i2v inference.") parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task") parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file or directory for audio-to-video (s2v) task") parser.add_argument("--image_strength", type=float, default=1.0, help="The strength of the image-to-audio-video (i2av) task") @@ -167,6 +170,7 @@ def main(): help="Path to action model checkpoint for WorldPlay models.", ) parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file") + parser.add_argument("--save_action_path", type=str, default=None, help="The path to save action predictions for Motus.") parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)") parser.add_argument("--target_shape", type=int, nargs="+", default=[], help="Set return video or image shape") parser.add_argument("--target_video_length", type=int, default=81, help="The target video length for each generated clip") diff --git a/lightx2v/models/networks/motus/__init__.py b/lightx2v/models/networks/motus/__init__.py new file mode 100644 index 000000000..15f71005a --- /dev/null +++ b/lightx2v/models/networks/motus/__init__.py @@ -0,0 +1,21 @@ +from .action_expert import ActionExpert, ActionExpertConfig +from .core import Motus, MotusConfig +from .primitives import WanLayerNorm, WanRMSNorm, rope_apply, sinusoidal_embedding_1d +from .t5 import T5EncoderModel +from .und_expert import UndExpert, UndExpertConfig +from .wan_model import WanVideoModel + +__all__ = [ + "Motus", + "MotusConfig", + "WanVideoModel", + "ActionExpert", + "ActionExpertConfig", + "UndExpert", + "UndExpertConfig", + "T5EncoderModel", + "WanLayerNorm", + "WanRMSNorm", + 
"sinusoidal_embedding_1d", + "rope_apply", +] diff --git a/lightx2v/models/networks/motus/action_expert.py b/lightx2v/models/networks/motus/action_expert.py new file mode 100644 index 000000000..004177abb --- /dev/null +++ b/lightx2v/models/networks/motus/action_expert.py @@ -0,0 +1,144 @@ +import logging +import re +from dataclasses import dataclass + +import numpy as np +import torch +import torch.nn as nn + +from .primitives import WanLayerNorm, WanRMSNorm + +logger = logging.getLogger(__name__) + + +def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos): + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega + if isinstance(pos, torch.Tensor): + pos = pos.cpu().numpy() + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) + emb = np.concatenate([np.sin(out), np.cos(out)], axis=1) + return torch.from_numpy(emb).float() + + +@dataclass +class ActionExpertConfig: + dim: int = 1024 + ffn_dim: int = 4096 + num_layers: int = 30 + state_dim: int = 14 + action_dim: int = 14 + chunk_size: int = 16 + video_feature_dim: int = 3072 + causal: bool = False + num_registers: int = 4 + eps: float = 1e-6 + training_mode: str = "finetune" + + def __post_init__(self): + assert self.chunk_size >= 2 + + +def build_mlp(projector_type, in_features, out_features): + if projector_type == "linear": + return nn.Linear(in_features, out_features) + mlp_silu_match = re.match(r"^mlp(\d+)x_silu$", projector_type) + if mlp_silu_match: + mlp_depth = int(mlp_silu_match.group(1)) + modules = [nn.Linear(in_features, out_features)] + for _ in range(1, mlp_depth): + modules.append(nn.SiLU()) + modules.append(nn.Linear(out_features, out_features)) + return nn.Sequential(*modules) + raise ValueError(f"Unknown projector type: {projector_type}") + + +class StateActionEncoder(nn.Module): + def __init__(self, config: ActionExpertConfig): + super().__init__() + self.state_encoder = build_mlp("mlp3x_silu", 
config.state_dim, config.dim) + self.action_encoder = build_mlp("mlp3x_silu", config.action_dim, config.dim) + max_seq_len = config.chunk_size + 1 + config.num_registers + pos_embed = get_1d_sincos_pos_embed_from_grid(config.dim, np.arange(max_seq_len)) + self.register_buffer("pos_embedding", pos_embed.unsqueeze(0)) + + def forward(self, state_tokens: torch.Tensor, action_tokens: torch.Tensor, registers: torch.Tensor = None) -> torch.Tensor: + encoded = torch.cat([self.state_encoder(state_tokens), self.action_encoder(action_tokens)], dim=1) + if registers is not None: + encoded = torch.cat([encoded, registers], dim=1) + return encoded + self.pos_embedding[:, : encoded.shape[1], :] + + +class ActionEncoder(nn.Module): + def __init__(self, config: ActionExpertConfig): + super().__init__() + self.action_encoder = build_mlp("mlp3x_silu", config.action_dim, config.dim) + max_seq_len = config.chunk_size + config.num_registers + pos_embed = get_1d_sincos_pos_embed_from_grid(config.dim, np.arange(max_seq_len)) + self.register_buffer("pos_embedding", pos_embed.unsqueeze(0)) + + def forward(self, state_tokens: torch.Tensor, action_tokens: torch.Tensor, registers: torch.Tensor = None) -> torch.Tensor: + encoded = self.action_encoder(action_tokens) + if registers is not None: + encoded = torch.cat([encoded, registers], dim=1) + return encoded + self.pos_embedding[:, : encoded.shape[1], :] + + +class ActionExpertBlock(nn.Module): + def __init__(self, config: ActionExpertConfig, wan_config: dict): + super().__init__() + self.norm1 = WanLayerNorm(config.dim, eps=config.eps) + self.norm2 = WanLayerNorm(config.dim, eps=config.eps) + self.wan_num_heads = wan_config["num_heads"] + self.wan_head_dim = wan_config["head_dim"] + self.wan_dim = wan_config["dim"] + self.wan_action_qkv = nn.Parameter(torch.randn(3, self.wan_num_heads, config.dim, self.wan_head_dim) / (config.dim * self.wan_head_dim) ** 0.5) + self.wan_action_o = nn.Linear(self.wan_dim, config.dim, bias=False) + 
self.wan_action_norm_q = WanRMSNorm(self.wan_dim, eps=config.eps) + self.wan_action_norm_k = WanRMSNorm(self.wan_dim, eps=config.eps) + self.ffn = nn.Sequential(nn.Linear(config.dim, config.ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(config.ffn_dim, config.dim)) + self.modulation = nn.Parameter(torch.randn(1, 6, config.dim) / config.dim**0.5) + + +class ActionDecoder(nn.Module): + def __init__(self, config: ActionExpertConfig): + super().__init__() + self.norm = WanLayerNorm(config.dim, eps=config.eps) + self.action_head = build_mlp("mlp1x_silu", config.dim, config.action_dim) + self.modulation = nn.Parameter(torch.randn(1, 2, config.dim) / config.dim**0.5) + + def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor: + with torch.amp.autocast("cuda", dtype=torch.float32): + e0, e1 = (self.modulation.unsqueeze(0) + time_emb.unsqueeze(2)).chunk(2, dim=2) + z = self.norm(x) * (1 + e1.squeeze(2)) + e0.squeeze(2) + return self.action_head(z) + + +class ActionExpert(nn.Module): + def __init__(self, config: ActionExpertConfig, wan_config: dict = None): + super().__init__() + self.config = config + self.freq_dim = 256 + self.input_encoder = ActionEncoder(config) if config.training_mode == "pretrain" else StateActionEncoder(config) + self.time_embedding = nn.Sequential(nn.Linear(self.freq_dim, config.dim), nn.SiLU(), nn.Linear(config.dim, config.dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(config.dim, config.dim * 6)) + block_cfg = wan_config or {"dim": 3072, "num_heads": 24, "head_dim": 128} + self.blocks = nn.ModuleList([ActionExpertBlock(config, block_cfg) for _ in range(config.num_layers)]) + self.registers = nn.Parameter(torch.empty(1, config.num_registers, config.dim).normal_(std=0.02)) if config.num_registers > 0 else None + self.decoder = ActionDecoder(config) + self.initialize_weights() + + def initialize_weights(self): + for module in self.modules(): + if isinstance(module, nn.Linear): + 
nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + nn.init.zeros_(self.decoder.action_head[-1].weight) + nn.init.zeros_(self.decoder.action_head[-1].bias) + for module in self.time_embedding.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=0.02) diff --git a/lightx2v/models/networks/motus/core.py b/lightx2v/models/networks/motus/core.py new file mode 100644 index 000000000..04670f8dd --- /dev/null +++ b/lightx2v/models/networks/motus/core.py @@ -0,0 +1,343 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional + +import torch +import torch.nn as nn +from transformers import AutoConfig, Qwen3VLForConditionalGeneration + +from .action_expert import ActionExpert, ActionExpertConfig +from .primitives import sinusoidal_embedding_1d +from .und_expert import UndExpert, UndExpertConfig +from .wan_model import WanVideoModel + +logger = logging.getLogger(__name__) + + +@dataclass +class MotusConfig: + wan_checkpoint_path: str + vae_path: str + wan_config_path: str + video_precision: str = "bfloat16" + vlm_checkpoint_path: str = "" + und_expert_hidden_size: int = 512 + und_expert_ffn_dim_multiplier: int = 4 + und_expert_norm_eps: float = 1e-5 + und_layers_to_extract: List[int] = None + vlm_adapter_input_dim: int = 2048 + vlm_adapter_projector_type: str = "mlp3x_silu" + num_layers: int = 30 + action_state_dim: int = 14 + action_dim: int = 14 + action_expert_dim: int = 1024 + action_expert_ffn_dim_multiplier: int = 4 + action_expert_norm_eps: float = 1e-6 + global_downsample_rate: int = 3 + video_action_freq_ratio: int = 2 + num_video_frames: int = 8 + video_height: int = 384 + video_width: int = 320 + batch_size: int = 1 + training_mode: str = "finetune" + load_pretrained_backbones: Optional[bool] = None + + def __post_init__(self): + self.action_chunk_size = self.num_video_frames * self.video_action_freq_ratio + if 
self.und_layers_to_extract is None: + self.und_layers_to_extract = list(range(self.num_layers)) + + +class VideoModule(nn.Module): + def __init__(self, video_model, dtype, device, grid_sizes): + super().__init__() + self.video_model = video_model + self.dtype = dtype + self.device = device + self.grid_sizes = grid_sizes + + def prepare_input(self, noisy_video_latent: torch.Tensor) -> torch.Tensor: + return self.video_model.wan_model.patch_embedding(noisy_video_latent).flatten(2).transpose(1, 2) + + def preprocess_t5_embeddings(self, language_embeddings) -> torch.Tensor: + if isinstance(language_embeddings, list): + text_len = self.video_model.wan_model.text_len + padded = [] + for emb in language_embeddings: + padded.append(torch.cat([emb, emb.new_zeros(text_len - emb.shape[0], emb.shape[1])]) if emb.shape[0] <= text_len else emb[:text_len]) + t5_context_raw = torch.stack(padded, dim=0) + else: + t5_context_raw = language_embeddings + return self.video_model.wan_model.text_embedding(t5_context_raw) + + def get_time_embedding(self, t_video: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + if t_video.dim() == 1: + t_video = t_video.unsqueeze(1).expand(t_video.size(0), seq_len) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t_video.size(0) + t_flat = t_video.flatten() + t_emb = self.video_model.wan_model.time_embedding(sinusoidal_embedding_1d(self.video_model.wan_model.freq_dim, t_flat).unflatten(0, (bt, seq_len)).float()) + t_emb_proj = self.video_model.wan_model.time_projection(t_emb).unflatten(2, (6, self.video_model.wan_model.dim)) + return t_emb, t_emb_proj + + def compute_adaln_modulation(self, video_adaln_params: torch.Tensor, layer_idx: int) -> tuple: + wan_layer = self.video_model.wan_model.blocks[layer_idx] + with torch.amp.autocast("cuda", dtype=torch.float32): + return (wan_layer.modulation.unsqueeze(0) + video_adaln_params).chunk(6, dim=2) + + def process_ffn(self, video_tokens: torch.Tensor, video_adaln_modulation: 
tuple, layer_idx: int) -> torch.Tensor: + wan_layer = self.video_model.wan_model.blocks[layer_idx] + v_mod = video_adaln_modulation + ffn_input = wan_layer.norm2(video_tokens).float() * (1 + v_mod[4].squeeze(2)) + v_mod[3].squeeze(2) + ffn_out = wan_layer.ffn(ffn_input) + with torch.amp.autocast("cuda", dtype=torch.float32): + return video_tokens + ffn_out * v_mod[5].squeeze(2) + + def apply_output_head(self, video_tokens: torch.Tensor, video_time_emb: torch.Tensor) -> torch.Tensor: + x = self.video_model.wan_model.head(video_tokens, video_time_emb) + x = self.video_model.wan_model.unpatchify(x, self.grid_sizes) + return torch.stack([u.float() for u in x], dim=0) + + +class UndModule(nn.Module): + def __init__(self, vlm_model, und_expert, config, dtype, device, image_context_adapter=None): + super().__init__() + self.vlm_model = vlm_model + self.und_expert = und_expert + self.config = config + self.dtype = dtype + self.device = device + self.image_context_adapter = image_context_adapter + + def _parse_vision_outputs(self, vision_outputs): + if hasattr(vision_outputs, "pooler_output"): + image_embeds = vision_outputs.pooler_output + deepstack_image_embeds = vision_outputs.get("hidden_states", None) if hasattr(vision_outputs, "get") else getattr(vision_outputs, "hidden_states", None) + elif isinstance(vision_outputs, tuple): + image_embeds = vision_outputs[0] + deepstack_image_embeds = vision_outputs[1] if len(vision_outputs) > 1 else None + else: + image_embeds = vision_outputs + deepstack_image_embeds = None + + if torch.is_tensor(image_embeds): + return image_embeds.to(self.device, self.dtype), deepstack_image_embeds + if isinstance(image_embeds, (list, tuple)): + return torch.cat(list(image_embeds), dim=0).to(self.device, self.dtype), deepstack_image_embeds + raise TypeError(f"Unsupported image feature output type: {type(image_embeds)}") + + def _process_vlm_inputs_to_tokens(self, vlm_inputs, batch: int): + if isinstance(vlm_inputs, list): + input_ids_batch = 
torch.cat([item["input_ids"] for item in vlm_inputs], dim=0).to(self.device) + attention_mask_batch = torch.cat([item["attention_mask"] for item in vlm_inputs], dim=0).to(self.device) + pixel_values_batch = torch.cat([item["pixel_values"] for item in vlm_inputs], dim=0).to(self.device) + image_grid_thw_batch = torch.cat([item["image_grid_thw"] for item in vlm_inputs], dim=0).to(self.device) + else: + input_ids_batch = vlm_inputs["input_ids"].to(self.device) + attention_mask_batch = vlm_inputs["attention_mask"].to(self.device) + pixel_values_batch = vlm_inputs["pixel_values"].to(self.device) + image_grid_thw_batch = vlm_inputs["image_grid_thw"].to(self.device) + + inputs_embeds = self.vlm_model.get_input_embeddings()(input_ids_batch) + vision_outputs = self.vlm_model.get_image_features(pixel_values_batch, image_grid_thw_batch) + image_embeds, deepstack_image_embeds = self._parse_vision_outputs(vision_outputs) + image_mask, _ = self.vlm_model.model.get_placeholder_mask(input_ids_batch, inputs_embeds=inputs_embeds, image_features=image_embeds) + inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) + visual_pos_masks = image_mask[..., 0] + position_ids, _ = self.vlm_model.model.get_rope_index( + input_ids=input_ids_batch, + image_grid_thw=image_grid_thw_batch, + video_grid_thw=None, + attention_mask=attention_mask_batch, + ) + return inputs_embeds, attention_mask_batch, visual_pos_masks, deepstack_image_embeds, position_ids + + def extract_und_features(self, vlm_inputs) -> torch.Tensor: + batch = len(vlm_inputs) if isinstance(vlm_inputs, list) else vlm_inputs["input_ids"].shape[0] + inputs_embeds, attention_mask, visual_pos_masks, deepstack_image_embeds, position_ids = self._process_vlm_inputs_to_tokens(vlm_inputs, batch) + kwargs = { + "inputs_embeds": inputs_embeds, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": None, + "use_cache": False, + "output_attentions": False, + "output_hidden_states": True, + 
"return_dict": True, + } + if visual_pos_masks is not None: + kwargs["visual_pos_masks"] = visual_pos_masks + if deepstack_image_embeds is not None: + kwargs["deepstack_visual_embeds"] = deepstack_image_embeds + with torch.no_grad(): + vlm_output = self.vlm_model.model.language_model(**kwargs) + return self.und_expert.vlm_adapter(vlm_output.hidden_states[-1]) + + def extract_image_context(self, vlm_inputs) -> torch.Tensor | None: + if self.image_context_adapter is None: + return None + + if isinstance(vlm_inputs, list): + pixel_values = torch.cat([item["pixel_values"] for item in vlm_inputs], dim=0).to(self.device) + image_grid_thw = torch.cat([item["image_grid_thw"] for item in vlm_inputs], dim=0).to(self.device) + else: + pixel_values = vlm_inputs["pixel_values"].to(self.device) + image_grid_thw = vlm_inputs["image_grid_thw"].to(self.device) + + with torch.no_grad(): + vision_outputs = self.vlm_model.get_image_features(pixel_values, image_grid_thw) + image_embeds, _ = self._parse_vision_outputs(vision_outputs) + return self.image_context_adapter(image_embeds) + + def process_ffn(self, und_tokens: torch.Tensor, layer_idx: int) -> torch.Tensor: + block = self.und_expert.blocks[layer_idx] + return und_tokens + block.ffn(block.norm2(und_tokens)) + + +class ActionModule(nn.Module): + def __init__(self, action_expert: ActionExpert, config, video_model, vlm_model, dtype, device): + super().__init__() + self.action_expert = action_expert + self.config = config + self.video_model = video_model + self.vlm_model = vlm_model + self.dtype = dtype + self.device = device + + def get_time_embedding(self, t: torch.Tensor, seq_len: int) -> tuple[torch.Tensor, torch.Tensor]: + if t.dim() == 1: + t = t.unsqueeze(1).expand(t.size(0), seq_len) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t.size(0) + t_flat = t.flatten() + a_e = self.action_expert.time_embedding(sinusoidal_embedding_1d(self.action_expert.freq_dim, t_flat).unflatten(0, (bt, seq_len)).float()) + a_e0 = 
self.action_expert.time_projection(a_e).unflatten(2, (6, self.config.action_expert_dim)) + return a_e, a_e0 + + def compute_adaln_modulation(self, action_adaln_params: torch.Tensor, layer_idx: int) -> tuple: + action_layer = self.action_expert.blocks[layer_idx] + with torch.amp.autocast("cuda", dtype=torch.float32): + return (action_layer.modulation.unsqueeze(0) + action_adaln_params).chunk(6, dim=2) + + def process_ffn(self, action_tokens: torch.Tensor, action_adaln_modulation: tuple, layer_idx: int) -> torch.Tensor: + action_block = self.action_expert.blocks[layer_idx] + a_mod = action_adaln_modulation + ffn_input = action_block.norm2(action_tokens).float() * (1 + a_mod[4].squeeze(2)) + a_mod[3].squeeze(2) + ffn_out = action_block.ffn(ffn_input) + with torch.amp.autocast("cuda", dtype=torch.float32): + return action_tokens + ffn_out * a_mod[5].squeeze(2) + + +class Motus(nn.Module): + def __init__(self, config: MotusConfig): + super().__init__() + self.config = config + self.dtype = torch.bfloat16 + load_backbones = True if config.load_pretrained_backbones is None else bool(config.load_pretrained_backbones) + + if load_backbones: + self.video_model = WanVideoModel.from_pretrained( + checkpoint_path=config.wan_checkpoint_path, + vae_path=config.vae_path, + config_path=config.wan_config_path, + precision=config.video_precision, + ) + else: + self.video_model = WanVideoModel.from_config( + config_path=config.wan_config_path, + vae_path=config.vae_path, + device="cuda", + precision=config.video_precision, + ) + + if load_backbones: + self.vlm_model = Qwen3VLForConditionalGeneration.from_pretrained( + config.vlm_checkpoint_path, + dtype=self.dtype, + device_map="cuda", + trust_remote_code=True, + ) + else: + vlm_cfg = AutoConfig.from_pretrained(config.vlm_checkpoint_path, trust_remote_code=True) + self.vlm_model = Qwen3VLForConditionalGeneration._from_config(vlm_cfg, torch_dtype=self.dtype) + self.vlm_model.to(device="cuda", dtype=self.dtype) + + for param in 
self.vlm_model.parameters(): + param.requires_grad = False + + wan_dim = getattr(self.video_model.wan_model.config, "dim", 3072) + wan_num_heads = getattr(self.video_model.wan_model.config, "num_heads", 24) + wan_head_dim = wan_dim // wan_num_heads + vlm_dim = self.vlm_model.config.text_config.hidden_size + vlm_num_heads = self.vlm_model.config.text_config.num_attention_heads + vlm_num_kv_heads = getattr(self.vlm_model.config.text_config, "num_key_value_heads", vlm_num_heads) + vlm_num_hidden_layers = self.vlm_model.config.text_config.num_hidden_layers + + wan_config = {"dim": wan_dim, "num_heads": wan_num_heads, "head_dim": wan_head_dim} + vlm_config = { + "hidden_size": vlm_dim, + "num_attention_heads": vlm_num_heads, + "num_key_value_heads": vlm_num_kv_heads, + "head_dim": vlm_dim // vlm_num_heads, + "num_hidden_layers": vlm_num_hidden_layers, + } + + action_chunk_size_for_expert = config.action_chunk_size if config.training_mode == "pretrain" else config.action_chunk_size + 1 + num_registers = 0 if config.training_mode == "pretrain" else 4 + action_config = ActionExpertConfig( + dim=config.action_expert_dim, + ffn_dim=config.action_expert_dim * config.action_expert_ffn_dim_multiplier, + num_layers=config.num_layers, + state_dim=config.action_state_dim, + action_dim=config.action_dim, + chunk_size=action_chunk_size_for_expert, + num_registers=num_registers, + video_feature_dim=wan_dim, + causal=False, + eps=config.action_expert_norm_eps, + training_mode=config.training_mode, + ) + self.action_expert = ActionExpert(action_config, wan_config) + + und_config = UndExpertConfig( + dim=config.und_expert_hidden_size, + ffn_dim=config.und_expert_hidden_size * config.und_expert_ffn_dim_multiplier, + num_layers=config.num_layers, + vlm_input_dim=config.vlm_adapter_input_dim, + vlm_projector_type=config.vlm_adapter_projector_type, + eps=config.und_expert_norm_eps, + ) + self.und_expert = UndExpert(und_config, wan_config, vlm_config) + self.image_context_adapter = 
nn.Sequential( + nn.Linear(vlm_dim, wan_dim), + nn.GELU(approximate="tanh"), + nn.Linear(wan_dim, wan_dim), + ) + + self.device = next(self.video_model.parameters()).device + self.action_expert.to(device=self.device, dtype=self.dtype) + self.und_expert.to(device=self.device, dtype=self.dtype) + self.image_context_adapter.to(device=self.device, dtype=self.dtype) + self.action_expert.time_embedding.to(dtype=torch.float32) + self.action_expert.time_projection.to(dtype=torch.float32) + + lat_t = 1 + config.num_video_frames // 4 + lat_h = config.video_height // 32 + lat_w = config.video_width // 32 + self.grid_sizes = torch.tensor([lat_t, lat_h, lat_w], dtype=torch.long, device=self.device).unsqueeze(0).expand(config.batch_size, -1) + self.video_module = VideoModule(self.video_model, self.dtype, self.device, self.grid_sizes) + self.und_module = UndModule(self.vlm_model, self.und_expert, self.config, self.dtype, self.device, image_context_adapter=self.image_context_adapter) + self.action_module = ActionModule(self.action_expert, self.config, self.video_model, self.vlm_model, self.dtype, self.device) + + def load_checkpoint(self, path: str, strict: bool = True) -> Dict: + checkpoint_path = Path(path) + if checkpoint_path.is_dir(): + checkpoint_file = checkpoint_path / "mp_rank_00_model_states.pt" + if not checkpoint_file.exists(): + raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_file}") + path = str(checkpoint_file) + checkpoint = torch.load(path, map_location="cpu") + state_dict = checkpoint["module"] + self.load_state_dict(state_dict, strict=strict) + return {key: value for key, value in checkpoint.items() if key not in ["module", "config"]} diff --git a/lightx2v/models/networks/motus/image_utils.py b/lightx2v/models/networks/motus/image_utils.py new file mode 100644 index 000000000..f4f0ab1d1 --- /dev/null +++ b/lightx2v/models/networks/motus/image_utils.py @@ -0,0 +1,19 @@ +import cv2 +import numpy as np + + +def resize_with_padding(frame: 
np.ndarray, target_size: tuple[int, int]) -> np.ndarray: + target_height, target_width = target_size + original_height, original_width = frame.shape[:2] + + scale = min(target_height / original_height, target_width / original_width) + new_height = int(original_height * scale) + new_width = int(original_width * scale) + + resized_frame = cv2.resize(frame, (new_width, new_height)) + padded_frame = np.zeros((target_height, target_width, frame.shape[2]), dtype=frame.dtype) + + y_offset = (target_height - new_height) // 2 + x_offset = (target_width - new_width) // 2 + padded_frame[y_offset : y_offset + new_height, x_offset : x_offset + new_width] = resized_frame + return padded_frame diff --git a/lightx2v/models/networks/motus/infer/__init__.py b/lightx2v/models/networks/motus/infer/__init__.py new file mode 100644 index 000000000..9279bcbda --- /dev/null +++ b/lightx2v/models/networks/motus/infer/__init__.py @@ -0,0 +1,5 @@ +from .post_infer import MotusPostInfer +from .pre_infer import MotusPreInfer +from .transformer_infer import MotusTransformerInfer + +__all__ = ["MotusPreInfer", "MotusTransformerInfer", "MotusPostInfer"] diff --git a/lightx2v/models/networks/motus/infer/module_io.py b/lightx2v/models/networks/motus/infer/module_io.py new file mode 100644 index 000000000..6f2df77ac --- /dev/null +++ b/lightx2v/models/networks/motus/infer/module_io.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import Any + +import torch + + +@dataclass +class MotusPreInferModuleOutput: + first_frame: torch.Tensor + state: torch.Tensor + instruction: str + t5_embeddings: list[torch.Tensor] + vlm_inputs: list[dict[str, Any]] + processed_t5_context: torch.Tensor + image_context: torch.Tensor | None + und_tokens: torch.Tensor + condition_frame_latent: torch.Tensor + grid_sizes: torch.Tensor + + +@dataclass +class MotusPostInferModuleOutput: + pred_frames: torch.Tensor + pred_actions: torch.Tensor diff --git a/lightx2v/models/networks/motus/infer/post_infer.py 
b/lightx2v/models/networks/motus/infer/post_infer.py new file mode 100644 index 000000000..01efc9d94 --- /dev/null +++ b/lightx2v/models/networks/motus/infer/post_infer.py @@ -0,0 +1,20 @@ +import torch + +from .module_io import MotusPostInferModuleOutput + + +class MotusPostInfer: + def __init__(self, adapter, config): + self.adapter = adapter + self.config = config + self.scheduler = None + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, video_latents: torch.Tensor, action_latents: torch.Tensor): + decoded_frames = self.adapter.model.video_model.decode_video(video_latents) + pred_frames = ((decoded_frames[:, :, 1:] + 1.0) / 2.0).clamp(0, 1).float() + pred_actions = self.adapter.denormalize_actions(action_latents.float()) + return MotusPostInferModuleOutput(pred_frames=pred_frames, pred_actions=pred_actions) diff --git a/lightx2v/models/networks/motus/infer/pre_infer.py b/lightx2v/models/networks/motus/infer/pre_infer.py new file mode 100644 index 000000000..ab5cfd1c0 --- /dev/null +++ b/lightx2v/models/networks/motus/infer/pre_infer.py @@ -0,0 +1,49 @@ +import torch + +from .module_io import MotusPreInferModuleOutput + + +class MotusPreInfer: + def __init__(self, adapter, config): + self.adapter = adapter + self.config = config + self.scheduler = None + + def set_scheduler(self, scheduler): + self.scheduler = scheduler + + @torch.no_grad() + def infer(self, image_path: str, prompt: str, state_value, seed: int | None = None): + if self.scheduler is None: + raise RuntimeError("MotusPreInfer requires a scheduler before infer().") + + first_frame = self.adapter.prepare_frame(image_path) + state = self.adapter.prepare_state(state_value) + instruction = self.adapter.build_instruction(prompt) + t5_embeddings = self.adapter.build_t5_embeddings(instruction) + vlm_inputs = [self.adapter.build_vlm_inputs(instruction, first_frame)] + condition_frame_latent = self.adapter.encode_condition_frame(first_frame) + 
class MotusTransformerInfer(BaseTransformerInfer):
    """Runs the Motus trimodal (video / action / understanding) denoising loop.

    Each diffusion step does joint self-attention across the three token
    streams, per-layer cross-attention from video tokens to the T5 text
    context, and expert-specific FFNs, then advances the scheduler with the
    predicted video and action velocities.
    """

    def __init__(self, adapter, config):
        # adapter: MotusModel wrapper giving access to the underlying Motus
        # model, RoPE helpers, and frequencies. config: flat dict of options.
        self.adapter = adapter
        self.config = config
        # Each attention site can be configured independently; all fall back
        # to the shared "attention_type" key, then to flash_attn2.
        self.self_attn_1_type = config.get("self_attn_1_type", config.get("attention_type", "flash_attn2"))
        self.self_attn_2_type = config.get("self_attn_2_type", config.get("attention_type", "flash_attn2"))
        self.cross_attn_1_type = config.get("cross_attn_1_type", config.get("attention_type", "flash_attn2"))
        # NOTE(review): self.self_attn is never used in this class's visible
        # code (only joint_self_attn and cross_attn are) — confirm it is
        # needed before removing.
        self.self_attn = RegistryAttention(self.self_attn_1_type)
        self.joint_self_attn = RegistryAttention(self.self_attn_2_type)
        self.cross_attn = RegistryAttention(self.cross_attn_1_type)

    def _joint_attention(self, pre_infer_out, video_tokens, action_tokens, und_tokens, video_adaln_modulation, action_adaln_modulation, layer_idx):
        """One layer of joint self-attention over concatenated video+action+und tokens.

        Video and action streams are AdaLN-modulated (shift/scale/gate from
        indices 0/1/2 of the modulation tuples); the understanding stream is
        plain LayerNorm with no gate. Returns the three updated streams.
        """
        model = self.adapter.model
        wan_layer = model.video_module.video_model.wan_model.blocks[layer_idx]
        action_block = model.action_expert.blocks[layer_idx]
        und_block = model.und_expert.blocks[layer_idx]

        v_mod = video_adaln_modulation
        a_mod = action_adaln_modulation
        # AdaLN: norm(x) * (1 + scale) + shift for the two diffusion streams.
        norm_video = wan_layer.norm1(video_tokens).float() * (1 + v_mod[1].squeeze(2)) + v_mod[0].squeeze(2)
        norm_action = action_block.norm1(action_tokens).float() * (1 + a_mod[1].squeeze(2)) + a_mod[0].squeeze(2)
        norm_und = und_block.norm1(und_tokens)

        batch, video_len, video_dim = norm_video.shape
        action_len = norm_action.shape[1]
        und_len = norm_und.shape[1]
        # NOTE(review): this reads model.video_model.wan_model while the
        # block above reads model.video_module.video_model.wan_model —
        # presumably both resolve to the same WanModel; confirm.
        num_heads = model.video_model.wan_model.num_heads
        head_dim = video_dim // num_heads

        # Video Q/K/V use the Wan layer's own projections; RoPE is applied
        # only to the video stream (positions come from the latent grid).
        video_q = wan_layer.self_attn.norm_q(wan_layer.self_attn.q(norm_video)).view(batch, video_len, num_heads, head_dim)
        video_k = wan_layer.self_attn.norm_k(wan_layer.self_attn.k(norm_video)).view(batch, video_len, num_heads, head_dim)
        video_v = wan_layer.self_attn.v(norm_video).view(batch, video_len, num_heads, head_dim)
        freqs = self.adapter.get_wan_freqs()
        video_q = self.adapter.rope_apply(video_q, pre_infer_out.grid_sizes, freqs)
        video_k = self.adapter.rope_apply(video_k, pre_infer_out.grid_sizes, freqs)

        # Action/und experts use packed QKV projectors; Q/K get RMS-norm,
        # V is used raw. No RoPE for these streams.
        action_q, action_k, action_v = action_block.wan_action_qkv_mm(norm_action)
        action_q = action_block.wan_action_norm_q(action_q.flatten(-2)).view(batch, action_len, num_heads, head_dim)
        action_k = action_block.wan_action_norm_k(action_k.flatten(-2)).view(batch, action_len, num_heads, head_dim)

        und_q, und_k, und_v = und_block.wan_und_qkv_mm(norm_und)
        und_q = und_block.wan_und_norm_q(und_q.flatten(-2)).view(batch, und_len, num_heads, head_dim)
        und_k = und_block.wan_und_norm_k(und_k.flatten(-2)).view(batch, und_len, num_heads, head_dim)

        # Single fused attention over [video | action | und] along sequence.
        q_all = torch.cat([video_q, action_q, und_q], dim=1)
        k_all = torch.cat([video_k, action_k, und_k], dim=1)
        v_all = torch.cat([video_v, action_v, und_v], dim=1)
        attn_out = self.joint_self_attn(q_all, k_all, v_all)

        # Split the fused output back into the three streams and project out.
        video_out = wan_layer.self_attn.o(attn_out[:, :video_len, :])
        action_out = action_block.wan_action_o(attn_out[:, video_len : video_len + action_len, :])
        und_out = und_block.wan_und_o(attn_out[:, video_len + action_len :, :])

        # Gated residuals for the diffusion streams; plain residual for und.
        video_tokens = video_tokens + video_out * v_mod[2].squeeze(2)
        action_tokens = action_tokens + action_out * a_mod[2].squeeze(2)
        und_tokens = und_tokens + und_out
        return video_tokens, action_tokens, und_tokens

    def _cross_attention(self, video_tokens, processed_t5_context, layer_idx):
        """Cross-attend video tokens to the preprocessed T5 text context (residual)."""
        wan_layer = self.adapter.model.video_module.video_model.wan_model.blocks[layer_idx]
        batch, q_len, dim = video_tokens.shape
        ctx_len = processed_t5_context.shape[1]
        num_heads = wan_layer.cross_attn.num_heads
        head_dim = dim // num_heads

        norm_video = wan_layer.norm3(video_tokens)
        q = wan_layer.cross_attn.norm_q(wan_layer.cross_attn.q(norm_video)).view(batch, q_len, num_heads, head_dim)
        k = wan_layer.cross_attn.norm_k(wan_layer.cross_attn.k(processed_t5_context)).view(batch, ctx_len, num_heads, head_dim)
        v = wan_layer.cross_attn.v(processed_t5_context).view(batch, ctx_len, num_heads, head_dim)
        return video_tokens + wan_layer.cross_attn.o(self.cross_attn(q, k, v))

    @torch.no_grad()
    def infer(self, weights, pre_infer_out):
        """Full denoising loop.

        Args:
            weights: unused here (kept for the BaseTransformerInfer interface).
            pre_infer_out: MotusPreInferModuleOutput with contexts and latents.

        Returns:
            (video_latents, action_latents) from the scheduler after all steps.
        """
        model = self.adapter.model
        scheduler = self.scheduler
        processed_t5_context = pre_infer_out.processed_t5_context
        # NOTE(review): image_context is read but never used in this loop —
        # confirm whether it should feed the und stream or can be dropped.
        image_context = pre_infer_out.image_context
        und_tokens_base = pre_infer_out.und_tokens

        for step_index, t, t_next, dt in scheduler.iter_steps():
            scheduler.step_pre(step_index)
            video_tokens = model.video_module.prepare_input(scheduler.video_latents.to(model.dtype))
            state_tokens = pre_infer_out.state.unsqueeze(1).to(model.dtype)
            # registers may be None when num_registers is configured as 0
            registers = model.action_expert.registers
            if registers is not None:
                registers = registers.expand(state_tokens.shape[0], -1, -1)
            action_tokens = model.action_expert.input_encoder(state_tokens, scheduler.action_latents, registers)
            # und tokens are re-seeded from the VLM features every step.
            und_tokens = und_tokens_base.clone()

            # Timesteps are scaled to the Wan convention (t in [0,1] -> [0,1000]).
            video_t_scaled = (t * 1000).expand(state_tokens.shape[0]).to(model.dtype)
            action_t_scaled = (t * 1000).expand(state_tokens.shape[0]).to(model.dtype)

            with torch.autocast(device_type="cuda", dtype=model.video_model.precision):
                video_head_time_emb, video_adaln_params = model.video_module.get_time_embedding(video_t_scaled, video_tokens.shape[1])
                action_head_time_emb, action_adaln_params = model.action_module.get_time_embedding(action_t_scaled, action_tokens.shape[1])

                for layer_idx in range(model.config.num_layers):
                    video_adaln_modulation = model.video_module.compute_adaln_modulation(video_adaln_params, layer_idx)
                    action_adaln_modulation = model.action_module.compute_adaln_modulation(action_adaln_params, layer_idx)
                    video_tokens, action_tokens, und_tokens = self._joint_attention(
                        pre_infer_out,
                        video_tokens,
                        action_tokens,
                        und_tokens,
                        video_adaln_modulation,
                        action_adaln_modulation,
                        layer_idx,
                    )
                    video_tokens = self._cross_attention(video_tokens, processed_t5_context, layer_idx)
                    video_tokens = model.video_module.process_ffn(video_tokens, video_adaln_modulation, layer_idx)
                    action_tokens = model.action_module.process_ffn(action_tokens, action_adaln_modulation, layer_idx)
                    und_tokens = model.und_module.process_ffn(und_tokens, layer_idx)

                video_velocity = model.video_module.apply_output_head(video_tokens, video_head_time_emb)
                action_pred_full = model.action_expert.decoder(action_tokens, action_head_time_emb)
                # Strip the leading state token and trailing registers (if any)
                # so only the action-chunk predictions remain.
                num_regs = model.action_expert.config.num_registers
                action_velocity = action_pred_full[:, 1:-num_regs, :] if num_regs > 0 else action_pred_full[:, 1:, :]

            scheduler.step(video_velocity=video_velocity, action_velocity=action_velocity, dt=dt, condition_frame_latent=pre_infer_out.condition_frame_latent)

        return scheduler.video_latents, scheduler.action_latents
_build_native_stack(self): + self.scheduler = MotusScheduler(self.config) + self.pre_infer = MotusPreInfer(self, self.config) + self.transformer_infer = MotusTransformerInfer(self, self.config) + self.post_infer = MotusPostInfer(self, self.config) + + self.pre_infer.set_scheduler(self.scheduler) + self.transformer_infer.set_scheduler(self.scheduler) + self.post_infer.set_scheduler(self.scheduler) + + def _build_model_config(self): + return self._motus_config_cls( + wan_checkpoint_path=self.config["wan_path"], + vae_path=os.path.join(self.config["wan_path"], "Wan2.2_VAE.pth"), + wan_config_path=self.config["wan_path"], + video_precision=self.config.get("video_precision", "bfloat16"), + vlm_checkpoint_path=self.config["vlm_path"], + und_expert_hidden_size=self.config.get("und_expert_hidden_size", 512), + und_expert_ffn_dim_multiplier=self.config.get("und_expert_ffn_dim_multiplier", 4), + und_expert_norm_eps=self.config.get("und_expert_norm_eps", 1e-5), + und_layers_to_extract=self.config.get("und_layers_to_extract"), + vlm_adapter_input_dim=self.config.get("vlm_adapter_input_dim", 2048), + vlm_adapter_projector_type=self.config.get("vlm_adapter_projector_type", "mlp3x_silu"), + num_layers=self.config.get("num_layers", 30), + action_state_dim=self.config.get("action_state_dim", 14), + action_dim=self.config.get("action_dim", 14), + action_expert_dim=self.config.get("action_expert_dim", 1024), + action_expert_ffn_dim_multiplier=self.config.get("action_expert_ffn_dim_multiplier", 4), + action_expert_norm_eps=self.config.get("action_expert_norm_eps", 1e-6), + global_downsample_rate=self.config.get("global_downsample_rate", 3), + video_action_freq_ratio=self.config.get("video_action_freq_ratio", 2), + num_video_frames=self.config.get("num_video_frames", 8), + video_height=self.config.get("video_height", 384), + video_width=self.config.get("video_width", 320), + batch_size=1, + training_mode=self.config.get("training_mode", "finetune"), + 
load_pretrained_backbones=self.config.get("load_pretrained_backbones", False), + ) + + def _load_model(self): + logger.info("Loading Motus model") + model = self._motus_cls(self._build_model_config()) + self._patch_qwen3_vl_rope_index(model) + model.to(self.device) + model.load_checkpoint(self.config["checkpoint_path"], strict=False) + self._apply_lightx2v_patches(model) + return model + + def _load_t5_encoder(self): + return self._t5_encoder_cls( + text_len=512, + dtype=torch.bfloat16, + device=str(self.device), + checkpoint_path=os.path.join(self.config["wan_path"], "models_t5_umt5-xxl-enc-bf16.pth"), + tokenizer_path=os.path.join(self.config["wan_path"], "google", "umt5-xxl"), + ) + + def _load_vlm_processor(self): + return AutoProcessor.from_pretrained(self.config["vlm_path"], trust_remote_code=True) + + def _patch_qwen3_vl_rope_index(self, root: Any): + visited = set() + + def walk(obj: Any): + obj_id = id(obj) + if obj is None or obj_id in visited: + return + visited.add(obj_id) + + method = getattr(obj, "get_rope_index", None) + if callable(method): + try: + signature = inspect.signature(method) + except (TypeError, ValueError): + signature = None + + if signature and "mm_token_type_ids" in signature.parameters: + + def wrapped_get_rope_index(*args, __orig=method, **kwargs): + if "mm_token_type_ids" not in kwargs: + input_ids = kwargs.get("input_ids") + if input_ids is None and args: + input_ids = args[0] + if torch.is_tensor(input_ids): + kwargs["mm_token_type_ids"] = torch.zeros_like(input_ids, dtype=torch.long) + return __orig(*args, **kwargs) + + setattr(obj, "get_rope_index", wrapped_get_rope_index) + + if isinstance(obj, torch.nn.Module): + for child in obj.children(): + walk(child) + + for attr in ("model", "language_model", "visual", "vlm", "backbone"): + child = getattr(obj, attr, None) + if child is not None and child is not obj: + walk(child) + + walk(root) + + def _load_normalization_stats(self): + stat_path = self.motus_root / "utils" / 
"stat.json" + if stat_path.exists(): + with open(stat_path, "r") as f: + stat_data = json.load(f) + stats = stat_data.get(self.config.get("stats_key", "robotwin2"), {}) + if stats: + self.action_min = torch.tensor(stats["min"], dtype=torch.float32, device=self.device) + self.action_max = torch.tensor(stats["max"], dtype=torch.float32, device=self.device) + self.action_range = self.action_max - self.action_min + return + + action_dim = self.config.get("action_dim", 14) + self.action_min = torch.zeros(action_dim, dtype=torch.float32, device=self.device) + self.action_max = torch.ones(action_dim, dtype=torch.float32, device=self.device) + self.action_range = torch.ones(action_dim, dtype=torch.float32, device=self.device) + + def _quant_flags(self): + quantized = bool(self.config.get("motus_quantized", self.config.get("dit_quantized", False))) + quant_scheme = self.config.get("motus_quant_scheme", self.config.get("dit_quant_scheme", "Default")) + return quantized, quant_scheme + + def _replace_linear_modules(self, module): + quantized, quant_scheme = self._quant_flags() + for name, child in list(module.named_children()): + if isinstance(child, nn.Linear): + setattr( + module, + name, + LinearWithMM.from_linear( + child, + quant_scheme=quant_scheme, + quantized=quantized, + config=self.config, + ), + ) + else: + self._replace_linear_modules(child) + + def _attach_qkv_projectors(self, model): + quantized, quant_scheme = self._quant_flags() + for block in model.action_expert.blocks: + block.wan_action_qkv_mm = TripleQKVProjector( + block.wan_action_qkv.detach(), + quant_scheme=quant_scheme, + quantized=quantized, + config=self.config, + ) + block.wan_action_o = LinearWithMM.from_linear( + block.wan_action_o, + quant_scheme=quant_scheme, + quantized=quantized, + config=self.config, + ) + + for block in model.und_expert.blocks: + block.wan_und_qkv_mm = TripleQKVProjector( + block.wan_und_qkv.detach(), + quant_scheme=quant_scheme, + quantized=quantized, + config=self.config, 
+ ) + block.wan_und_o = LinearWithMM.from_linear( + block.wan_und_o, + quant_scheme=quant_scheme, + quantized=quantized, + config=self.config, + ) + + def _apply_lightx2v_patches(self, model): + self._replace_linear_modules(model.action_expert) + self._replace_linear_modules(model.und_expert) + self._attach_qkv_projectors(model) + + def denormalize_actions(self, actions: torch.Tensor) -> torch.Tensor: + shape = actions.shape + flat = actions.reshape(-1, shape[-1]) + restored = flat * self.action_range.unsqueeze(0) + self.action_min.unsqueeze(0) + return restored.reshape(shape) + + def rope_apply(self, q: torch.Tensor, grid_sizes: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor: + return self._rope_apply(q, grid_sizes, freqs) + + def get_wan_freqs(self) -> torch.Tensor: + freqs = self.model.video_model.wan_model.freqs + if freqs.device != self.device: + freqs = freqs.to(self.device) + return freqs + + def prepare_frame(self, image_path: str) -> torch.Tensor: + image = Image.open(image_path).convert("RGB") + image_np = np.asarray(image).astype(np.float32) / 255.0 + resized_np = self._resize_with_padding( + image_np, + (self.config.get("video_height", 384), self.config.get("video_width", 320)), + ) + if resized_np.dtype == np.uint8: + resized_np = resized_np.astype(np.float32) / 255.0 + return torch.from_numpy(resized_np).permute(2, 0, 1).unsqueeze(0).to(self.device) + + def prepare_state(self, state_value) -> torch.Tensor: + if isinstance(state_value, torch.Tensor): + state = state_value.float() + else: + state = torch.tensor(state_value, dtype=torch.float32) + if state.dim() == 1: + state = state.unsqueeze(0) + return state.to(self.device) + + def build_instruction(self, prompt: str) -> str: + prefix = self.config.get( + "scene_prefix", + "The whole scene is in a realistic, industrial art style with three views: " + "a fixed rear camera, a movable left arm camera, and a movable right arm camera. 
" + "The aloha robot is currently performing the following task: ", + ) + return f"{prefix}{prompt}" + + def build_t5_embeddings(self, instruction: str): + t5_out = self.t5_encoder([instruction], str(self.device)) + if isinstance(t5_out, torch.Tensor): + return [t5_out.squeeze(0)] if t5_out.dim() == 3 else [t5_out] + return t5_out + + def _tensor_to_pil(self, tensor: torch.Tensor) -> Image.Image: + tensor = tensor.float().clamp(0, 1) + np_img = (tensor.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + return Image.fromarray(np_img, mode="RGB") + + def build_vlm_inputs(self, instruction: str, first_frame: torch.Tensor): + image = self._tensor_to_pil(first_frame.squeeze(0)) + messages = [{"role": "user", "content": [{"type": "text", "text": instruction}, {"type": "image", "image": image}]}] + text = self.vlm_processor.apply_chat_template(messages, add_generation_prompt=False, tokenize=False) + encoded = self.vlm_processor(text=[text], images=[image], return_tensors="pt") + + vlm_inputs = {} + for key in ("input_ids", "attention_mask", "pixel_values", "image_grid_thw", "video_grid_thw", "second_per_grid_ts", "mm_token_type_ids"): + value = encoded.get(key) + if torch.is_tensor(value): + vlm_inputs[key] = value.to(self.device) + elif value is not None: + vlm_inputs[key] = value + + if "mm_token_type_ids" not in vlm_inputs and "input_ids" in vlm_inputs: + vlm_inputs["mm_token_type_ids"] = torch.zeros_like(vlm_inputs["input_ids"], dtype=torch.long) + return vlm_inputs + + @torch.no_grad() + def encode_condition_frame(self, first_frame: torch.Tensor): + first_frame_norm = (first_frame * 2.0 - 1.0).unsqueeze(2) + return self.model.video_model.encode_video(first_frame_norm.to(self.model.dtype)) + + @torch.no_grad() + def infer(self, image_path: str, prompt: str, state_value, num_inference_steps: int, seed: int | None = None): + self.scheduler.infer_steps = num_inference_steps + pre_infer_out = self.pre_infer.infer(image_path=image_path, prompt=prompt, 
class LinearWithMM(nn.Module):
    """nn.Linear-compatible module with optional LightX2V MM backend.

    When quantized (and the scheme is not "Default"), forward() routes
    through a matmul weight object from MM_WEIGHT_REGISTER; otherwise it is
    a plain F.linear. Construct via from_linear / from_tensor so the MM
    backend is built eagerly with the final device/dtype weights.
    """

    def __init__(self, in_features, out_features, bias=True, quant_scheme="Default", quantized=False, config=None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Uninitialized storage; real values arrive via from_linear/from_tensor.
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)
        self.quant_scheme = quant_scheme
        # "Default" scheme means no quantization even if the flag is set.
        self.quantized = quantized and quant_scheme != "Default"
        # Deep copy so later _build_mm mutations never leak into the caller's config.
        self.config = copy.deepcopy(config or {})
        # MM backend is built lazily (or eagerly by the factory methods).
        self.mm = None

    @classmethod
    def from_linear(cls, linear: nn.Linear, quant_scheme="Default", quantized=False, config=None):
        """Clone an nn.Linear (weights, bias, device, dtype) into a LinearWithMM."""
        module = cls(
            linear.in_features,
            linear.out_features,
            bias=linear.bias is not None,
            quant_scheme=quant_scheme,
            quantized=quantized,
            config=config,
        )
        with torch.no_grad():
            module.weight.copy_(linear.weight.detach())
            if linear.bias is not None:
                module.bias.copy_(linear.bias.detach())
        # Match the source layer's placement before the MM backend snapshots weights.
        module = module.to(device=linear.weight.device, dtype=linear.weight.dtype)
        module._build_mm()
        return module

    @classmethod
    def from_tensor(cls, weight: torch.Tensor, bias: torch.Tensor | None = None, quant_scheme="Default", quantized=False, config=None):
        """Build a LinearWithMM directly from a [out, in] weight tensor (and optional bias)."""
        out_features, in_features = weight.shape
        module = cls(
            in_features,
            out_features,
            bias=bias is not None,
            quant_scheme=quant_scheme,
            quantized=quantized,
            config=config,
        )
        with torch.no_grad():
            module.weight.copy_(weight.detach())
            if bias is not None:
                module.bias.copy_(bias.detach())
        module = module.to(device=weight.device, dtype=weight.dtype)
        module._build_mm()
        return module

    def _build_mm(self):
        """Instantiate and load the registry matmul backend for the current weights."""
        # Non-quantized modules still get a "Default" backend object.
        scheme = self.quant_scheme if self.quantized else "Default"
        self.mm = MM_WEIGHT_REGISTER[scheme]("__motus_weight__", "__motus_bias__" if self.bias is not None else None)
        if hasattr(self.mm, "set_config"):
            cfg = copy.deepcopy(self.config)
            if self.quantized:
                # Reuse the DiT quantization keys the backends understand.
                cfg["dit_quantized"] = True
                cfg["dit_quant_scheme"] = self.quant_scheme
                cfg.setdefault("weight_auto_quant", True)
            self.mm.set_config(cfg)
        weight_dict = {"__motus_weight__": self.weight.detach()}
        if self.bias is not None:
            weight_dict["__motus_bias__"] = self.bias.detach()
        self.mm.load(weight_dict)

    def _mm_apply(self, x):
        """Run the MM backend on x, flattening to 2-D and restoring shape/dtype."""
        if self.mm is None:
            self._build_mm()
        x2d = x.reshape(-1, x.shape[-1])
        y2d = self.mm.apply(x2d.to(self.weight.dtype))
        if y2d.dtype != x.dtype:
            y2d = y2d.to(x.dtype)
        return y2d.reshape(*x.shape[:-1], self.out_features)

    def forward(self, x):
        # Fast path: plain linear when quantization is disabled.
        if not self.quantized:
            return F.linear(x, self.weight, self.bias)
        return self._mm_apply(x)
class RegistryAttention(nn.Module):
    """Wraps a LightX2V attention kernel behind Wan-style varlen arguments."""

    def __init__(self, attn_type: str):
        super().__init__()
        self.attn_type = attn_type
        self.kernel = ATTN_WEIGHT_REGISTER[attn_type]()

    def _build_cu_seqlens(self, batch: int, seq_len: int, device: torch.device):
        # Cumulative boundaries [0, L, 2L, ...] for batch equal-length sequences.
        return torch.arange(0, (batch + 1) * seq_len, seq_len, dtype=torch.int32, device=device)

    def _normalize_dtype(self, tensor: torch.Tensor) -> torch.Tensor:
        # Half-precision inputs pass through untouched.
        if tensor.dtype in (torch.float16, torch.bfloat16):
            return tensor
        # GPU kernels expect bf16; on CPU fall back to float32.
        target = torch.bfloat16 if tensor.device.type == "cuda" else torch.float32
        return tensor.to(target)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False):
        """Run the registered attention kernel on [B, L, H, D] q/k/v tensors."""
        if any(t.dim() != 4 for t in (q, k, v)):
            raise ValueError("RegistryAttention expects q/k/v with shape [B, L, H, D].")

        q, k, v = (self._normalize_dtype(t) for t in (q, k, v))

        batch, q_len = q.shape[:2]
        kv_len = k.shape[1]
        result = self.kernel.apply(
            q=q,
            k=k,
            v=v,
            causal=causal,
            cu_seqlens_q=self._build_cu_seqlens(batch, q_len, q.device),
            cu_seqlens_kv=self._build_cu_seqlens(batch, kv_len, k.device),
            max_seqlen_q=q_len,
            max_seqlen_kv=kv_len,
        )
        # Varlen kernels may return a packed 2-D tensor; restore the batch axis.
        if result.dim() == 2:
            result = result.view(batch, q_len, -1)
        return result
def build_condition_adapter(projector_type, in_features, out_features):
    """Build the VLM-to-expert projection module.

    Supported types: "linear" (single nn.Linear) or "mlp{N}x_silu"
    (N Linear layers with SiLU activations between them).

    Raises:
        ValueError: for any other projector_type string.
    """
    if projector_type == "linear":
        return nn.Linear(in_features, out_features)

    match = re.fullmatch(r"mlp(\d+)x_silu", projector_type)
    if match is None:
        raise ValueError(f"Unknown projector type: {projector_type}")

    depth = int(match.group(1))
    # First layer changes the width; the remaining depth-1 stay at out_features.
    layers = [nn.Linear(in_features, out_features)]
    for _ in range(depth - 1):
        layers.append(nn.SiLU())
        layers.append(nn.Linear(out_features, out_features))
    return nn.Sequential(*layers)
import warnings

import torch

# Optional backends: FA3 exposes flash_attn_interface, FA2 exposes flash_attn.
try:
    import flash_attn_interface

    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn

    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False


def flash_attention(
    q,
    k,
    v,
    q_lens=None,
    k_lens=None,
    dropout_p=0.0,
    softmax_scale=None,
    q_scale=None,
    causal=False,
    window_size=(-1, -1),
    deterministic=False,
    dtype=torch.bfloat16,
    version=None,
):
    """Variable-length flash attention over [B, L, H, D] tensors.

    Packs q/k/v into varlen layout (trimming each sample to q_lens/k_lens
    when given), dispatches to flash-attn 3 if available (or 2 as fallback),
    and returns [B, Lq, H, D] in the input dtype. Requires CUDA tensors with
    head dim <= 256 and a half-precision compute dtype.
    """
    half_dtypes = (torch.float16, torch.bfloat16)
    assert dtype in half_dtypes
    assert q.device.type == "cuda" and q.size(-1) <= 256

    # params
    batch, q_len, kv_len, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

    def half(x):
        # Cast to the compute dtype only when not already half precision.
        return x if x.dtype in half_dtypes else x.to(dtype)

    # preprocess query: flatten to packed [sum(q_lens), H, D]
    if q_lens is None:
        q = half(q.flatten(0, 1))
        q_lens = torch.tensor([q_len] * batch, dtype=torch.int32).to(device=q.device, non_blocking=True)
    else:
        # Per-sample trim then concat (drops padding beyond each length).
        q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))

    # preprocess key, value the same way
    if k_lens is None:
        k = half(k.flatten(0, 1))
        v = half(v.flatten(0, 1))
        k_lens = torch.tensor([kv_len] * batch, dtype=torch.int32).to(device=k.device, non_blocking=True)
    else:
        k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
        v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))

    q = q.to(v.dtype)
    k = k.to(v.dtype)
    if q_scale is not None:
        q = q * q_scale

    if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
        warnings.warn("Flash attention 3 is not available, using flash attention 2 instead.")

    # apply attention
    if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
        # FA3 returns (out, lse); keep only the output.
        # NOTE(review): unflatten(0, (batch, q_len)) assumes every sample has
        # length q_len — with ragged q_lens this would mis-shape; confirm
        # callers only pass equal lengths.
        x = flash_attn_interface.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            seqused_q=None,
            seqused_k=None,
            max_seqlen_q=q_len,
            max_seqlen_k=kv_len,
            softmax_scale=softmax_scale,
            causal=causal,
            deterministic=deterministic,
        )[0].unflatten(0, (batch, q_len))
    else:
        assert FLASH_ATTN_2_AVAILABLE
        x = flash_attn.flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
            max_seqlen_q=q_len,
            max_seqlen_k=kv_len,
            dropout_p=dropout_p,
            softmax_scale=softmax_scale,
            causal=causal,
            window_size=window_size,
            deterministic=deterministic,
        ).unflatten(0, (batch, q_len))

    # output in the caller's original dtype
    return x.type(out_dtype)
@torch.amp.autocast("cuda", enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    """Precompute complex RoPE rotation factors of shape [max_seq_len, dim // 2].

    Entry (p, i) is exp(j * p * theta^(-2i/dim)): a unit-magnitude complex
    rotation for position p and frequency channel i, in float64 precision.
    """
    assert dim % 2 == 0
    # Per-channel inverse frequencies theta^(-2i/dim) for even indices i.
    exponents = torch.arange(0, dim, 2).to(torch.float64).div(dim)
    inv_freq = 1.0 / torch.pow(theta, exponents)
    # Outer product position x frequency gives the rotation angles.
    angles = torch.outer(torch.arange(max_seq_len), inv_freq)
    # Unit-magnitude complex numbers e^{i*angle}.
    return torch.polar(torch.ones_like(angles), angles)
idx.numel() == 0: + continue + seq_len = f * h * w + + freq_grid = _make_freq_grid(f, h, w) # [seq_len,1,C] + + y_c[idx, :seq_len] = x_c[idx, :seq_len] * freq_grid + + y = torch.view_as_real(y_c).reshape(B, T, N, -1).float() + # assert rope_apply_original(x, grid_sizes, freqs).allclose(y, atol=1e-5) + return y + + +@torch.amp.autocast("cuda", enabled=False) +def rope_apply_original(x, grid_sizes, freqs): + n, c = x.size(2), x.size(3) // 2 + + # split freqs + freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) + + # loop over samples + output = [] + for i, (f, h, w) in enumerate(grid_sizes.tolist()): + seq_len = f * h * w + + # precompute multipliers + x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)) + freqs_i = torch.cat( + [freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1), freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1), freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)], dim=-1 + ).reshape(seq_len, 1, -1) + + # apply rotary embedding + x_i = torch.view_as_real(x_i * freqs_i).flatten(2) + x_i = torch.cat([x_i, x[i, seq_len:]]) + + # append to collection + output.append(x_i) + return torch.stack(output).float() + + +class WanRMSNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return self._norm(x.float()).type_as(x) * self.weight + + def _norm(self, x): + return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps) + + +class WanLayerNorm(nn.LayerNorm): + def __init__(self, dim, eps=1e-6, elementwise_affine=False): + super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps) + + def forward(self, x): + r""" + Args: + x(Tensor): Shape [B, L, C] + """ + return super().forward(x.float()).type_as(x) + + +class WanSelfAttention(nn.Module): + def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, 
eps=1e-6): + assert dim % num_heads == 0 + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.eps = eps + + # layers + self.q = nn.Linear(dim, dim) + self.k = nn.Linear(dim, dim) + self.v = nn.Linear(dim, dim) + self.o = nn.Linear(dim, dim) + self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity() + + def forward( + self, + x, + seq_lens, + grid_sizes, + freqs, + action_q: torch.Tensor = None, + action_k: torch.Tensor = None, + action_v: torch.Tensor = None, + und_q: torch.Tensor = None, + und_k: torch.Tensor = None, + und_v: torch.Tensor = None, + ): + r""" + Args: + x(Tensor): Shape [B, L, num_heads, C / num_heads] + seq_lens(Tensor): Shape [B] + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + action_q/k/v(Tensor, optional): Action expert Q/K/V for trimodal MoT + und_q/k/v(Tensor, optional): Understanding expert Q/K/V for trimodal MoT + """ + b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim + + # query, key, value function + def qkv_fn(x): + q = self.norm_q(self.q(x)).view(b, s, n, d) + k = self.norm_k(self.k(x)).view(b, s, n, d) + v = self.v(x).view(b, s, n, d) + return q, k, v + + q, k, v = qkv_fn(x) + + # Trimodal MoT branch: WAN + Action + Understanding + if action_q is not None or und_q is not None: + L_x = q.size(1) + + # Apply RoPE only to video tokens (q, k) + q_video_rope = rope_apply(q, grid_sizes, freqs) + k_video_rope = rope_apply(k, grid_sizes, freqs) + + # Prepare parts for concatenation + q_parts = [q_video_rope] + k_parts = [k_video_rope] + v_parts = [v] + + # Add action tokens if provided + if action_q is not None: + q_parts.append(action_q) + k_parts.append(action_k) + v_parts.append(action_v) + L_action = action_q.size(1) + else: + L_action = 0 + + # Add 
understanding tokens if provided + if und_q is not None: + q_parts.append(und_q) + k_parts.append(und_k) + v_parts.append(und_v) + L_und = und_q.size(1) + else: + L_und = 0 + + # Concatenate all modalities + q_cat = torch.cat(q_parts, dim=1) + k_cat = torch.cat(k_parts, dim=1) + v_cat = torch.cat(v_parts, dim=1) + + attn_out = flash_attention(q=q_cat, k=k_cat, v=v_cat, k_lens=seq_lens, window_size=self.window_size) + + # Split outputs back to respective modalities + x_out = attn_out[:, :L_x, :, :] + outputs = [x_out] + + start_idx = L_x + if action_q is not None: + action_out = attn_out[:, start_idx : start_idx + L_action, :, :] + outputs.append(action_out) + start_idx += L_action + else: + outputs.append(None) + + if und_q is not None: + und_out = attn_out[:, start_idx : start_idx + L_und, :, :] + outputs.append(und_out) + else: + outputs.append(None) + + # Project WAN branch; other branches returned in head shape for external projection + x_out = x_out.flatten(2) + x_out = self.o(x_out) + outputs[0] = x_out + + return tuple(outputs) + + # Standard branch (no MoT) + x = flash_attention(q=rope_apply(q, grid_sizes, freqs), k=rope_apply(k, grid_sizes, freqs), v=v, k_lens=seq_lens, window_size=self.window_size) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanCrossAttention(WanSelfAttention): + def forward(self, x, context, context_lens): + r""" + Args: + x(Tensor): Shape [B, L1, C] + context(Tensor): Shape [B, L2, C] + context_lens(Tensor): Shape [B] + """ + b, n, d = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.norm_q(self.q(x)).view(b, -1, n, d) + k = self.norm_k(self.k(context)).view(b, -1, n, d) + v = self.v(context).view(b, -1, n, d) + + # compute attention + x = flash_attention(q, k, v, k_lens=context_lens) + + # output + x = x.flatten(2) + x = self.o(x) + return x + + +class WanAttentionBlock(nn.Module): + def __init__(self, dim, ffn_dim, num_heads, window_size=(-1, -1), qk_norm=True, 
cross_attn_norm=False, eps=1e-6): + super().__init__() + self.dim = dim + self.ffn_dim = ffn_dim + self.num_heads = num_heads + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # layers + self.norm1 = WanLayerNorm(dim, eps) + self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps) + + self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity() + self.cross_attn = WanCrossAttention(dim, num_heads, (-1, -1), qk_norm, eps) + self.norm2 = WanLayerNorm(dim, eps) + self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim)) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5) + + def forward( + self, + x, + e, + seq_lens, + grid_sizes, + freqs, + context, + context_lens, + ): + r""" + Args: + x(Tensor): Shape [B, L, C] + e(Tensor): Shape [B, L1, 6, C] + seq_lens(Tensor): Shape [B], length of each sequence in batch + grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W) + freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2) + assert e[0].dtype == torch.float32 + + # self-attention + y = self.self_attn(self.norm1(x).float() * (1 + e[1].squeeze(2)) + e[0].squeeze(2), seq_lens, grid_sizes, freqs) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[2].squeeze(2) + + # cross-attention & ffn function + def cross_attn_ffn(x, context, context_lens, e): + x = x + self.cross_attn(self.norm3(x), context, context_lens) + y = self.ffn(self.norm2(x).float() * (1 + e[4].squeeze(2)) + e[3].squeeze(2)) + with torch.amp.autocast("cuda", dtype=torch.float32): + x = x + y * e[5].squeeze(2) + return x + + x = cross_attn_ffn(x, context, context_lens, e) + return x + + +class Head(nn.Module): + def 
__init__(self, dim, out_dim, patch_size, eps=1e-6): + super().__init__() + self.dim = dim + self.out_dim = out_dim + self.patch_size = patch_size + self.eps = eps + + # layers + out_dim = math.prod(patch_size) * out_dim + self.norm = WanLayerNorm(dim, eps) + self.head = nn.Linear(dim, out_dim) + + # modulation + self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5) + + def forward(self, x, e): + r""" + Args: + x(Tensor): Shape [B, L1, C] + e(Tensor): Shape [B, L1, C] + """ + assert e.dtype == torch.float32 + with torch.amp.autocast("cuda", dtype=torch.float32): + e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2) + x = self.head(self.norm(x) * (1 + e[1].squeeze(2)) + e[0].squeeze(2)) + return x + + +class WanModel(ModelMixin, ConfigMixin): + r""" + Wan diffusion backbone supporting both text-to-video and image-to-video. + """ + + ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim", "window_size"] + _no_split_modules = ["WanAttentionBlock"] + + @register_to_config + def __init__( + self, + model_type="t2v", + patch_size=(1, 2, 2), + text_len=512, + in_dim=16, + dim=2048, + ffn_dim=8192, + freq_dim=256, + text_dim=4096, + out_dim=16, + num_heads=16, + num_layers=32, + window_size=(-1, -1), + qk_norm=True, + cross_attn_norm=True, + eps=1e-6, + ): + r""" + Initialize the diffusion model backbone. 
+ + Args: + model_type (`str`, *optional*, defaults to 't2v'): + Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video) + patch_size (`tuple`, *optional*, defaults to (1, 2, 2)): + 3D patch dimensions for video embedding (t_patch, h_patch, w_patch) + text_len (`int`, *optional*, defaults to 512): + Fixed length for text embeddings + in_dim (`int`, *optional*, defaults to 16): + Input video channels (C_in) + dim (`int`, *optional*, defaults to 2048): + Hidden dimension of the transformer + ffn_dim (`int`, *optional*, defaults to 8192): + Intermediate dimension in feed-forward network + freq_dim (`int`, *optional*, defaults to 256): + Dimension for sinusoidal time embeddings + text_dim (`int`, *optional*, defaults to 4096): + Input dimension for text embeddings + out_dim (`int`, *optional*, defaults to 16): + Output video channels (C_out) + num_heads (`int`, *optional*, defaults to 16): + Number of attention heads + num_layers (`int`, *optional*, defaults to 32): + Number of transformer blocks + window_size (`tuple`, *optional*, defaults to (-1, -1)): + Window size for local attention (-1 indicates global attention) + qk_norm (`bool`, *optional*, defaults to True): + Enable query/key normalization + cross_attn_norm (`bool`, *optional*, defaults to False): + Enable cross-attention normalization + eps (`float`, *optional*, defaults to 1e-6): + Epsilon value for normalization layers + """ + + super().__init__() + + assert model_type in ["t2v", "i2v", "ti2v"] + self.model_type = model_type + + self.patch_size = patch_size + self.text_len = text_len + self.in_dim = in_dim + self.dim = dim + self.ffn_dim = ffn_dim + self.freq_dim = freq_dim + self.text_dim = text_dim + self.out_dim = out_dim + self.num_heads = num_heads + self.num_layers = num_layers + self.window_size = window_size + self.qk_norm = qk_norm + self.cross_attn_norm = cross_attn_norm + self.eps = eps + + # embeddings + self.patch_embedding = nn.Conv3d(in_dim, dim, kernel_size=patch_size, 
stride=patch_size) + self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim)) + + self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)) + self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)) + + # blocks + self.blocks = nn.ModuleList([WanAttentionBlock(dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps) for _ in range(num_layers)]) + + # head + self.head = Head(dim, out_dim, patch_size, eps) + + # buffers (don't use register_buffer otherwise dtype will be changed in to()) + assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0 + d = dim // num_heads + self.freqs = torch.cat([rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1) + + # initialize weights + self.init_weights() + + def forward( + self, + x, + t, + context, + seq_len, + y=None, + ): + r""" + Forward pass through the diffusion model + + Args: + x (List[Tensor]): + List of input video tensors, each with shape [C_in, F, H, W] + t (Tensor): + Diffusion timesteps tensor of shape [B] + context (List[Tensor]): + List of text embeddings each with shape [L, C] + seq_len (`int`): + Maximum sequence length for positional encoding + y (List[Tensor], *optional*): + Conditional video inputs for image-to-video mode, same shape as x + + Returns: + List[Tensor]: + List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8] + """ + if self.model_type == "i2v": + assert y is not None + # params + device = self.patch_embedding.weight.device + if self.freqs.device != device: + self.freqs = self.freqs.to(device) + + if y is not None: + x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)] + + # embeddings + x = [self.patch_embedding(u.unsqueeze(0)) for u in x] + grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x]) + x = [u.flatten(2).transpose(1, 2) for u in x] + seq_lens = 
torch.tensor([u.size(1) for u in x], dtype=torch.long) + assert seq_lens.max() <= seq_len + x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x]) + + # time embeddings + if t.dim() == 1: + t = t.expand(t.size(0), seq_len) + with torch.amp.autocast("cuda", dtype=torch.float32): + bt = t.size(0) + t = t.flatten() + e = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t).unflatten(0, (bt, seq_len)).float()) + e0 = self.time_projection(e).unflatten(2, (6, self.dim)) + assert e.dtype == torch.float32 and e0.dtype == torch.float32 + + # context + context_lens = None + context = self.text_embedding(torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])) + + # arguments + kwargs = dict(e=e0, seq_lens=seq_lens, grid_sizes=grid_sizes, freqs=self.freqs, context=context, context_lens=context_lens) + + for block in self.blocks: + x = block(x, **kwargs) + + # head + x = self.head(x, e) + + # unpatchify + x = self.unpatchify(x, grid_sizes) + return [u.float() for u in x] + + def unpatchify(self, x, grid_sizes): + r""" + Reconstruct video tensors from patch embeddings. 
+ + Args: + x (List[Tensor]): + List of patchified features, each with shape [L, C_out * prod(patch_size)], e.g. with patch_size (1, 2, 2): [360, 48 * 4] = [360, 192] + grid_sizes (Tensor): + Original spatial-temporal grid dimensions before patching, + shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches) + + Returns: + List[Tensor]: + Reconstructed video tensors with shape [C_out, F, H / 8, W / 8] + """ + # On entry, x.shape = [B, L, C_out * prod(patch_size)], e.g. [1, 360, 192] + c = self.out_dim + out = [] + # grid_sizes.tolist() = [[3, 12, 10]] + for u, v in zip(x, grid_sizes.tolist()): + # Trim any padded tokens and reshape the sequence back into the original patch grid + # The sequence may have been padded/aligned, so keep only the first F_patches * H_patches * W_patches tokens + # [F_patches, H_patches, W_patches, pF, pH, pW, C_out], e.g. [3, 12, 10, 1, 2, 2, 48] + u = u[: math.prod(v)].view(*v, *self.patch_size, c) + # Permute so that patch-grid coordinates and intra-patch offsets interleave + # f h w: patch-grid coordinates; p q r: intra-patch offsets; c: channels + # After permute: [C_out, F_patches, pF (frames per patch), H_patches, pH (pixels per patch along height), W_patches, pW] + # The next reshape merges F_patches with pF into the full frame count F, H_patches with pH into the full height, and W_patches with pW into the full width + u = torch.einsum("fhwpqrc->cfphqwr", u) + # The reshape below computes, per axis: [ + # F_patches * pF, + # H_patches * pH, + # W_patches * pW + # ], so the final shape is [C_out, F, H, W] + u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)]) # shape = [48, 3, 24, 20] + out.append(u) # one element per batch sample + return out + + def init_weights(self): + r""" + Initialize model parameters using Xavier initialization.
+ """ + + # basic init + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + + # init embeddings + nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1)) + for m in self.text_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + for m in self.time_embedding.modules(): + if isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.02) + + # init output layer + nn.init.zeros_(self.head.head.weight) diff --git a/lightx2v/models/networks/motus/wan/t5.py b/lightx2v/models/networks/motus/wan/t5.py new file mode 100644 index 000000000..e52c11fff --- /dev/null +++ b/lightx2v/models/networks/motus/wan/t5.py @@ -0,0 +1,403 @@ +# Modified from transformers.models.t5.modeling_t5 +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .tokenizers import HuggingfaceTokenizer + +__all__ = [ + "T5Model", + "T5Encoder", + "T5Decoder", + "T5EncoderModel", +] + + +def fp16_clamp(x): + if x.dtype == torch.float16 and torch.isinf(x).any(): + clamp = torch.finfo(x.dtype).max - 1000 + x = torch.clamp(x, min=-clamp, max=clamp) + return x + + +def init_weights(m): + if isinstance(m, T5LayerNorm): + nn.init.ones_(m.weight) + elif isinstance(m, T5Model): + nn.init.normal_(m.token_embedding.weight, std=1.0) + elif isinstance(m, T5FeedForward): + nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5) + nn.init.normal_(m.fc1.weight, std=m.dim**-0.5) + nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5) + elif isinstance(m, T5Attention): + nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn) ** -0.5) + nn.init.normal_(m.k.weight, std=m.dim**-0.5) + nn.init.normal_(m.v.weight, std=m.dim**-0.5) + nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn) ** -0.5) + elif isinstance(m, T5RelativeEmbedding): + 
nn.init.normal_(m.embedding.weight, std=(2 * m.num_buckets * m.num_heads) ** -0.5) + + +class GELU(nn.Module): + def forward(self, x): + return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) + + +class T5LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-6): + super(T5LayerNorm, self).__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps) + if self.weight.dtype in [torch.float16, torch.bfloat16]: + x = x.type_as(self.weight) + return self.weight * x + + +class T5Attention(nn.Module): + def __init__(self, dim, dim_attn, num_heads, dropout=0.1): + assert dim_attn % num_heads == 0 + super(T5Attention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.num_heads = num_heads + self.head_dim = dim_attn // num_heads + + # layers + self.q = nn.Linear(dim, dim_attn, bias=False) + self.k = nn.Linear(dim, dim_attn, bias=False) + self.v = nn.Linear(dim, dim_attn, bias=False) + self.o = nn.Linear(dim_attn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, context=None, mask=None, pos_bias=None): + """ + x: [B, L1, C]. + context: [B, L2, C] or None. + mask: [B, L2] or [B, L1, L2] or None. 
+ """ + # check inputs + context = x if context is None else context + b, n, c = x.size(0), self.num_heads, self.head_dim + + # compute query, key, value + q = self.q(x).view(b, -1, n, c) + k = self.k(context).view(b, -1, n, c) + v = self.v(context).view(b, -1, n, c) + + # attention bias + attn_bias = x.new_zeros(b, n, q.size(1), k.size(1)) + if pos_bias is not None: + attn_bias += pos_bias + if mask is not None: + assert mask.ndim in [2, 3] + mask = mask.view(b, 1, 1, -1) if mask.ndim == 2 else mask.unsqueeze(1) + attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min) + + # compute attention (T5 does not use scaling) + attn = torch.einsum("binc,bjnc->bnij", q, k) + attn_bias + attn = F.softmax(attn.float(), dim=-1).type_as(attn) + x = torch.einsum("bnij,bjnc->binc", attn, v) + + # output + x = x.reshape(b, -1, n * c) + x = self.o(x) + x = self.dropout(x) + return x + + +class T5FeedForward(nn.Module): + def __init__(self, dim, dim_ffn, dropout=0.1): + super(T5FeedForward, self).__init__() + self.dim = dim + self.dim_ffn = dim_ffn + + # layers + self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU()) + self.fc1 = nn.Linear(dim, dim_ffn, bias=False) + self.fc2 = nn.Linear(dim_ffn, dim, bias=False) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.fc1(x) * self.gate(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.dropout(x) + return x + + +class T5SelfAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super(T5SelfAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else 
T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) + + def forward(self, x, mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.ffn(self.norm2(x))) + return x + + +class T5CrossAttention(nn.Module): + def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.1): + super(T5CrossAttention, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.norm1 = T5LayerNorm(dim) + self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm2 = T5LayerNorm(dim) + self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout) + self.norm3 = T5LayerNorm(dim) + self.ffn = T5FeedForward(dim, dim_ffn, dropout) + self.pos_embedding = None if shared_pos else T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) + + def forward(self, x, mask=None, encoder_states=None, encoder_mask=None, pos_bias=None): + e = pos_bias if self.shared_pos else self.pos_embedding(x.size(1), x.size(1)) + x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e)) + x = fp16_clamp(x + self.cross_attn(self.norm2(x), context=encoder_states, mask=encoder_mask)) + x = fp16_clamp(x + self.ffn(self.norm3(x))) + return x + + +class T5RelativeEmbedding(nn.Module): + def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128): + super(T5RelativeEmbedding, self).__init__() + self.num_buckets = num_buckets + self.num_heads = num_heads + self.bidirectional = bidirectional + self.max_dist = max_dist + + # layers + self.embedding = nn.Embedding(num_buckets, num_heads) + + def forward(self, lq, lk): + device = self.embedding.weight.device + # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \ + # 
torch.arange(lq).unsqueeze(1).to(device) + rel_pos = torch.arange(lk, device=device).unsqueeze(0) - torch.arange(lq, device=device).unsqueeze(1) + rel_pos = self._relative_position_bucket(rel_pos) + rel_pos_embeds = self.embedding(rel_pos) + rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(0) # [1, N, Lq, Lk] + return rel_pos_embeds.contiguous() + + def _relative_position_bucket(self, rel_pos): + # preprocess + if self.bidirectional: + num_buckets = self.num_buckets // 2 + rel_buckets = (rel_pos > 0).long() * num_buckets + rel_pos = torch.abs(rel_pos) + else: + num_buckets = self.num_buckets + rel_buckets = 0 + rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos)) + + # embeddings for small and large positions + max_exact = num_buckets // 2 + rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) / math.log(self.max_dist / max_exact) * (num_buckets - max_exact)).long() + rel_pos_large = torch.min(rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1)) + rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large) + return rel_buckets + + +class T5Encoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super(T5Encoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=True) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None): 
+ x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Decoder(nn.Module): + def __init__(self, vocab, dim, dim_attn, dim_ffn, num_heads, num_layers, num_buckets, shared_pos=True, dropout=0.1): + super(T5Decoder, self).__init__() + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = num_heads + self.num_layers = num_layers + self.num_buckets = num_buckets + self.shared_pos = shared_pos + + # layers + self.token_embedding = vocab if isinstance(vocab, nn.Embedding) else nn.Embedding(vocab, dim) + self.pos_embedding = T5RelativeEmbedding(num_buckets, num_heads, bidirectional=False) if shared_pos else None + self.dropout = nn.Dropout(dropout) + self.blocks = nn.ModuleList([T5CrossAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout) for _ in range(num_layers)]) + self.norm = T5LayerNorm(dim) + + # initialize weights + self.apply(init_weights) + + def forward(self, ids, mask=None, encoder_states=None, encoder_mask=None): + b, s = ids.size() + + # causal mask + if mask is None: + mask = torch.tril(torch.ones(1, s, s).to(ids.device)) + elif mask.ndim == 2: + mask = torch.tril(mask.unsqueeze(1).expand(-1, s, -1)) + + # layers + x = self.token_embedding(ids) + x = self.dropout(x) + e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None + for block in self.blocks: + x = block(x, mask, encoder_states, encoder_mask, pos_bias=e) + x = self.norm(x) + x = self.dropout(x) + return x + + +class T5Model(nn.Module): + def __init__(self, vocab_size, dim, dim_attn, dim_ffn, num_heads, encoder_layers, decoder_layers, num_buckets, shared_pos=True, dropout=0.1): + super(T5Model, self).__init__() + self.vocab_size = vocab_size + self.dim = dim + self.dim_attn = dim_attn + self.dim_ffn = dim_ffn + self.num_heads = 
num_heads + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.num_buckets = num_buckets + + # layers + self.token_embedding = nn.Embedding(vocab_size, dim) + self.encoder = T5Encoder(self.token_embedding, dim, dim_attn, dim_ffn, num_heads, encoder_layers, num_buckets, shared_pos, dropout) + self.decoder = T5Decoder(self.token_embedding, dim, dim_attn, dim_ffn, num_heads, decoder_layers, num_buckets, shared_pos, dropout) + self.head = nn.Linear(dim, vocab_size, bias=False) + + # initialize weights + self.apply(init_weights) + + def forward(self, encoder_ids, encoder_mask, decoder_ids, decoder_mask): + x = self.encoder(encoder_ids, encoder_mask) + x = self.decoder(decoder_ids, decoder_mask, x, encoder_mask) + x = self.head(x) + return x + + +def _t5(name, encoder_only=False, decoder_only=False, return_tokenizer=False, tokenizer_kwargs={}, dtype=torch.float32, device="cpu", **kwargs): + # sanity check + assert not (encoder_only and decoder_only) + + # params + if encoder_only: + model_cls = T5Encoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("encoder_layers") + _ = kwargs.pop("decoder_layers") + elif decoder_only: + model_cls = T5Decoder + kwargs["vocab"] = kwargs.pop("vocab_size") + kwargs["num_layers"] = kwargs.pop("decoder_layers") + _ = kwargs.pop("encoder_layers") + else: + model_cls = T5Model + + # init model + with torch.device(device): + model = model_cls(**kwargs) + + # set device + model = model.to(dtype=dtype, device=device) + + # init tokenizer + if return_tokenizer: + from .tokenizers import HuggingfaceTokenizer + + tokenizer = HuggingfaceTokenizer(f"google/{name}", **tokenizer_kwargs) + return model, tokenizer + else: + return model + + +def umt5_xxl(**kwargs): + cfg = dict(vocab_size=256384, dim=4096, dim_attn=4096, dim_ffn=10240, num_heads=64, encoder_layers=24, decoder_layers=24, num_buckets=32, shared_pos=False, dropout=0.1) + cfg.update(**kwargs) + return _t5("umt5-xxl", **cfg) 
+ + +class T5EncoderModel: + def __init__( + self, + text_len, + dtype=torch.bfloat16, + device=torch.cuda.current_device(), + checkpoint_path=None, + tokenizer_path=None, + shard_fn=None, + ): + self.text_len = text_len + self.dtype = dtype + self.device = device + self.checkpoint_path = checkpoint_path + self.tokenizer_path = tokenizer_path + + # init model + model = umt5_xxl(encoder_only=True, return_tokenizer=False, dtype=dtype, device=device).eval().requires_grad_(False) + logging.info(f"loading {checkpoint_path}") + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")) + self.model = model + if shard_fn is not None: + self.model = shard_fn(self.model, sync_module_states=False) + else: + self.model.to(self.device) + # init tokenizer + self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=text_len, clean="whitespace") + + def __call__(self, texts, device): + ids, mask = self.tokenizer(texts, return_mask=True, add_special_tokens=True) + ids = ids.to(device) + mask = mask.to(device) + seq_lens = mask.gt(0).sum(dim=1).long() + context = self.model(ids, mask) + return [u[:v] for u, v in zip(context, seq_lens)] diff --git a/lightx2v/models/networks/motus/wan/tokenizers.py b/lightx2v/models/networks/motus/wan/tokenizers.py new file mode 100644 index 000000000..36f72caa7 --- /dev/null +++ b/lightx2v/models/networks/motus/wan/tokenizers.py @@ -0,0 +1,60 @@ +import html +import string + +import ftfy +import regex as re +from transformers import AutoTokenizer + + +def basic_clean(text): + text = ftfy.fix_text(text) + text = html.unescape(html.unescape(text)) + return text.strip() + + +def whitespace_clean(text): + text = re.sub(r"\s+", " ", text) + return text.strip() + + +def canonicalize(text, keep_punctuation_exact_string=None): + text = text.replace("_", " ") + if keep_punctuation_exact_string: + text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in 
text.split(keep_punctuation_exact_string)) + else: + text = text.translate(str.maketrans("", "", string.punctuation)) + text = text.lower() + text = re.sub(r"\s+", " ", text) + return text.strip() + + +class HuggingfaceTokenizer: + def __init__(self, name, seq_len=None, clean=None, **kwargs): + assert clean in (None, "whitespace", "lower", "canonicalize") + self.name = name + self.seq_len = seq_len + self.clean = clean + self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs) + self.vocab_size = self.tokenizer.vocab_size + + def __call__(self, sequence, **kwargs): + return_mask = kwargs.pop("return_mask", False) + local_kwargs = {"return_tensors": "pt"} + if self.seq_len is not None: + local_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len}) + local_kwargs.update(**kwargs) + if isinstance(sequence, str): + sequence = [sequence] + if self.clean: + sequence = [self._clean(item) for item in sequence] + ids = self.tokenizer(sequence, **local_kwargs) + return (ids.input_ids, ids.attention_mask) if return_mask else ids.input_ids + + def _clean(self, text): + if self.clean == "whitespace": + text = whitespace_clean(basic_clean(text)) + elif self.clean == "lower": + text = whitespace_clean(basic_clean(text)).lower() + elif self.clean == "canonicalize": + text = canonicalize(basic_clean(text)) + return text diff --git a/lightx2v/models/networks/motus/wan/vae2_2.py b/lightx2v/models/networks/motus/wan/vae2_2.py new file mode 100644 index 000000000..9bfaf32a4 --- /dev/null +++ b/lightx2v/models/networks/motus/wan/vae2_2.py @@ -0,0 +1,985 @@ +# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved. +import logging + +import torch +import torch.cuda.amp as amp +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +__all__ = [ + "Wan2_2_VAE", +] + +CACHE_T = 2 + + +class CausalConv3d(nn.Conv3d): + """ + Causal 3d convolusion. 
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._padding = ( + self.padding[2], + self.padding[2], + self.padding[1], + self.padding[1], + 2 * self.padding[0], + 0, + ) + self.padding = (0, 0, 0) + + def forward(self, x, cache_x=None): + padding = list(self._padding) + if cache_x is not None and self._padding[4] > 0: + cache_x = cache_x.to(x.device) + x = torch.cat([cache_x, x], dim=2) + padding[4] -= cache_x.shape[2] + x = F.pad(x, padding) + + return super().forward(x) + + +class RMS_norm(nn.Module): + def __init__(self, dim, channel_first=True, images=True, bias=False): + super().__init__() + broadcastable_dims = (1, 1, 1) if not images else (1, 1) + shape = (dim, *broadcastable_dims) if channel_first else (dim,) + + self.channel_first = channel_first + self.scale = dim**0.5 + self.gamma = nn.Parameter(torch.ones(shape)) + self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0 + + def forward(self, x): + return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias + + +class Upsample(nn.Upsample): + def forward(self, x): + """ + Fix bfloat16 support for nearest neighbor interpolation. 
+ """ + return super().forward(x.float()).type_as(x) + + +class Resample(nn.Module): + def __init__(self, dim, mode): + assert mode in ( + "none", + "upsample2d", + "upsample3d", + "downsample2d", + "downsample3d", + ) + super().__init__() + self.dim = dim + self.mode = mode + + # layers + if mode == "upsample2d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + ) + elif mode == "upsample3d": + self.resample = nn.Sequential( + Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), + nn.Conv2d(dim, dim, 3, padding=1), + # nn.Conv2d(dim, dim//2, 3, padding=1) + ) + self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0)) + elif mode == "downsample2d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + elif mode == "downsample3d": + self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))) + self.time_conv = CausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)) + else: + self.resample = nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + b, c, t, h, w = x.size() + if self.mode == "upsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = "Rep" + feat_idx[0] += 1 + else: + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep": + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep": + cache_x = torch.cat( + [torch.zeros_like(cache_x).to(cache_x.device), cache_x], + dim=2, + ) + if feat_cache[idx] == "Rep": + x = self.time_conv(x) + else: + x = self.time_conv(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + x 
= x.reshape(b, 2, c, t, h, w) + x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3) + x = x.reshape(b, c, t * 2, h, w) + t = x.shape[2] + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.resample(x) + x = rearrange(x, "(b t) c h w -> b c t h w", t=t) + + if self.mode == "downsample3d": + if feat_cache is not None: + idx = feat_idx[0] + if feat_cache[idx] is None: + feat_cache[idx] = x.clone() + feat_idx[0] += 1 + else: + cache_x = x[:, :, -1:, :, :].clone() + x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + return x + + def init_weight(self, conv): + conv_weight = conv.weight.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + one_matrix = torch.eye(c1, c2) + init_matrix = one_matrix + nn.init.zeros_(conv_weight) + conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5 + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + def init_weight2(self, conv): + conv_weight = conv.weight.data.detach().clone() + nn.init.zeros_(conv_weight) + c1, c2, t, h, w = conv_weight.size() + init_matrix = torch.eye(c1 // 2, c2) + conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix + conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix + conv.weight = nn.Parameter(conv_weight) + nn.init.zeros_(conv.bias.data) + + +class ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout=0.0): + super().__init__() + self.in_dim = in_dim + self.out_dim = out_dim + + # layers + self.residual = nn.Sequential( + RMS_norm(in_dim, images=False), + nn.SiLU(), + CausalConv3d(in_dim, out_dim, 3, padding=1), + RMS_norm(out_dim, images=False), + nn.SiLU(), + nn.Dropout(dropout), + CausalConv3d(out_dim, out_dim, 3, padding=1), + ) + self.shortcut = CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity() + + def forward(self, x, feat_cache=None, feat_idx=[0]): + h = self.shortcut(x) + for layer in self.residual: + if isinstance(layer, 
CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + # cache last frame of last two chunk + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + h + + +class AttentionBlock(nn.Module): + """ + Causal self-attention with a single head. + """ + + def __init__(self, dim): + super().__init__() + self.dim = dim + + # layers + self.norm = RMS_norm(dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x): + identity = x + b, c, t, h, w = x.size() + x = rearrange(x, "b c t h w -> (b t) c h w") + x = self.norm(x) + # compute query, key, value + q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3, -1).permute(0, 1, 3, 2).contiguous().chunk(3, dim=-1) + + # apply attention + x = F.scaled_dot_product_attention( + q, + k, + v, + ) + x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w) + + # output + x = self.proj(x) + x = rearrange(x, "(b t) c h w-> b c t h w", t=t) + return x + identity + + +def patchify(x, patch_size): + if patch_size == 1: + return x + if x.dim() == 4: + x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b c f (h q) (w r) -> b (c r q) f h w", + q=patch_size, + r=patch_size, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size): + if patch_size == 1: + return x + + if x.dim() == 4: + x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size) + elif x.dim() == 5: + x = rearrange( + x, + "b (c r q) f h w -> b c f (h q) (w r)", + q=patch_size, + r=patch_size, + ) + return x + + +class 
AvgDown3D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert in_channels * self.factor % out_channels == 0 + self.group_size = in_channels * self.factor // out_channels + + def forward(self, x: torch.Tensor) -> torch.Tensor: + pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t + pad = (0, 0, 0, 0, pad_t, 0) + x = F.pad(x, pad) + B, C, T, H, W = x.shape + x = x.view( + B, + C, + T // self.factor_t, + self.factor_t, + H // self.factor_s, + self.factor_s, + W // self.factor_s, + self.factor_s, + ) + x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous() + x = x.view( + B, + C * self.factor, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.view( + B, + self.out_channels, + self.group_size, + T // self.factor_t, + H // self.factor_s, + W // self.factor_s, + ) + x = x.mean(dim=2) + return x + + +class DupUp3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + factor_t, + factor_s=1, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.factor_t = factor_t + self.factor_s = factor_s + self.factor = self.factor_t * self.factor_s * self.factor_s + + assert out_channels * self.factor % in_channels == 0 + self.repeats = out_channels * self.factor // in_channels + + def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor: + x = x.repeat_interleave(self.repeats, dim=1) + x = x.view( + x.size(0), + self.out_channels, + self.factor_t, + self.factor_s, + self.factor_s, + x.size(2), + x.size(3), + x.size(4), + ) + x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() + x = x.view( + x.size(0), + self.out_channels, + x.size(2) * self.factor_t, + x.size(4) * self.factor_s, + x.size(6) * 
self.factor_s, + ) + if first_chunk: + x = x[:, :, self.factor_t - 1 :, :, :] + return x + + +class Down_ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, mult, temperal_downsample=False, down_flag=False): + super().__init__() + + # Shortcut path with downsample + self.avg_shortcut = AvgDown3D( + in_dim, + out_dim, + factor_t=2 if temperal_downsample else 1, + factor_s=2 if down_flag else 1, + ) + + # Main path with residual blocks and downsample + downsamples = [] + for _ in range(mult): + downsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final downsample block + if down_flag: + mode = "downsample3d" if temperal_downsample else "downsample2d" + downsamples.append(Resample(out_dim, mode=mode)) + + self.downsamples = nn.Sequential(*downsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + x_copy = x.clone() + for module in self.downsamples: + x = module(x, feat_cache, feat_idx) + + return x + self.avg_shortcut(x_copy) + + +class Up_ResidualBlock(nn.Module): + def __init__(self, in_dim, out_dim, dropout, mult, temperal_upsample=False, up_flag=False): + super().__init__() + # Shortcut path with upsample + if up_flag: + self.avg_shortcut = DupUp3D( + in_dim, + out_dim, + factor_t=2 if temperal_upsample else 1, + factor_s=2 if up_flag else 1, + ) + else: + self.avg_shortcut = None + + # Main path with residual blocks and upsample + upsamples = [] + for _ in range(mult): + upsamples.append(ResidualBlock(in_dim, out_dim, dropout)) + in_dim = out_dim + + # Add the final upsample block + if up_flag: + mode = "upsample3d" if temperal_upsample else "upsample2d" + upsamples.append(Resample(out_dim, mode=mode)) + + self.upsamples = nn.Sequential(*upsamples) + + def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False): + x_main = x.clone() + for module in self.upsamples: + x_main = module(x_main, feat_cache, feat_idx) + if self.avg_shortcut is not None: + x_shortcut = 
self.avg_shortcut(x, first_chunk) + return x_main + x_shortcut + else: + return x_main + + +class Encoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + + # dimensions + dims = [dim * u for u in [1] + dim_mult] + scale = 1.0 + + # init block + self.conv1 = CausalConv3d(12, dims[0], 3, padding=1) + + # downsample blocks + downsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_down_flag = temperal_downsample[i] if i < len(temperal_downsample) else False + downsamples.append( + Down_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks, + temperal_downsample=t_down_flag, + down_flag=i != len(dim_mult) - 1, + ) + ) + scale /= 2.0 + self.downsamples = nn.Sequential(*downsamples) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(out_dim, out_dim, dropout), + AttentionBlock(out_dim), + ResidualBlock(out_dim, out_dim, dropout), + ) + + # # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, z_dim, 3, padding=1), + ) + + def forward(self, x, feat_cache=None, feat_idx=[0]): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + ## downsamples + for layer in self.downsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = 
layer(x) + + ## middle + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + + return x + + +class Decoder3d(nn.Module): + def __init__( + self, + dim=128, + z_dim=4, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_upsample=[False, True, True], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_upsample = temperal_upsample + + # dimensions + dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + scale = 1.0 / 2 ** (len(dim_mult) - 2) + # init block + self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1) + + # middle blocks + self.middle = nn.Sequential( + ResidualBlock(dims[0], dims[0], dropout), + AttentionBlock(dims[0]), + ResidualBlock(dims[0], dims[0], dropout), + ) + + # upsample blocks + upsamples = [] + for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])): + t_up_flag = temperal_upsample[i] if i < len(temperal_upsample) else False + upsamples.append( + Up_ResidualBlock( + in_dim=in_dim, + out_dim=out_dim, + dropout=dropout, + mult=num_res_blocks + 1, + temperal_upsample=t_up_flag, + up_flag=i != len(dim_mult) - 1, + ) + ) + self.upsamples = nn.Sequential(*upsamples) + + # output blocks + self.head = nn.Sequential( + RMS_norm(out_dim, images=False), + nn.SiLU(), + CausalConv3d(out_dim, 12, 3, padding=1), + ) + + def forward(self, x, 
feat_cache=None, feat_idx=[0], first_chunk=False): + if feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = self.conv1(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = self.conv1(x) + + for layer in self.middle: + if isinstance(layer, ResidualBlock) and feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + ## upsamples + for layer in self.upsamples: + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, first_chunk) + else: + x = layer(x) + + ## head + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + idx = feat_idx[0] + cache_x = x[:, :, -CACHE_T:, :, :].clone() + if cache_x.shape[2] < 2 and feat_cache[idx] is not None: + cache_x = torch.cat( + [ + feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), + cache_x, + ], + dim=2, + ) + x = layer(x, feat_cache[idx]) + feat_cache[idx] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + return x + + +def count_conv3d(model): + count = 0 + for m in model.modules(): + if isinstance(m, CausalConv3d): + count += 1 + return count + + +class WanVAE_(nn.Module): + def __init__( + self, + dim=160, + dec_dim=256, + z_dim=16, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, False], + dropout=0.0, + ): + super().__init__() + self.dim = dim + self.z_dim = z_dim + self.dim_mult = dim_mult + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.temperal_downsample = temperal_downsample + self.temperal_upsample = temperal_downsample[::-1] + + # modules + self.encoder = Encoder3d( + dim, + z_dim * 2, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_downsample, + dropout, + ) + self.conv1 = 
CausalConv3d(z_dim * 2, z_dim * 2, 1) + self.conv2 = CausalConv3d(z_dim, z_dim, 1) + self.decoder = Decoder3d( + dec_dim, + z_dim, + dim_mult, + num_res_blocks, + attn_scales, + self.temperal_upsample, + dropout, + ) + + def forward(self, x, scale=[0, 1]): + mu = self.encode(x, scale) + x_recon = self.decode(mu, scale) + return x_recon, mu + + def encode(self, x, scale): + self.clear_cache() + x = patchify(x, patch_size=2) + t = x.shape[2] + iter_ = 1 + (t - 1) // 4 + for i in range(iter_): + self._enc_conv_idx = [0] + if i == 0: + out = self.encoder( + x[:, :, :1, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + else: + out_ = self.encoder( + x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], + feat_cache=self._enc_feat_map, + feat_idx=self._enc_conv_idx, + ) + out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) + if isinstance(scale[0], torch.Tensor): + mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1) + else: + mu = (mu - scale[0]) * scale[1] + self.clear_cache() + return mu + + def decode(self, z, scale): + self.clear_cache() + if isinstance(scale[0], torch.Tensor): + z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1) + else: + z = z / scale[1] + scale[0] + iter_ = z.shape[2] + x = self.conv2(z) + for i in range(iter_): + self._conv_idx = [0] + if i == 0: + out = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + first_chunk=True, + ) + else: + out_ = self.decoder( + x[:, :, i : i + 1, :, :], + feat_cache=self._feat_map, + feat_idx=self._conv_idx, + ) + out = torch.cat([out, out_], 2) + out = unpatchify(out, patch_size=2) + self.clear_cache() + return out + + def reparameterize(self, mu, log_var): + std = torch.exp(0.5 * log_var) + eps = torch.randn_like(std) + return eps * std + mu + + def sample(self, imgs, deterministic=False): + mu, log_var = self.encode(imgs) + if deterministic: + 
return mu + std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0)) + return mu + std * torch.randn_like(std) + + def clear_cache(self): + self._conv_num = count_conv3d(self.decoder) + self._conv_idx = [0] + self._feat_map = [None] * self._conv_num + # cache encode + self._enc_conv_num = count_conv3d(self.encoder) + self._enc_conv_idx = [0] + self._enc_feat_map = [None] * self._enc_conv_num + + +def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs): + # params + cfg = dict( + dim=dim, + z_dim=z_dim, + dim_mult=[1, 2, 4, 4], + num_res_blocks=2, + attn_scales=[], + temperal_downsample=[True, True, True], + dropout=0.0, + ) + cfg.update(**kwargs) + + # init model + with torch.device("meta"): + model = WanVAE_(**cfg) + + # load checkpoint + logging.info(f"loading {pretrained_path}") + model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True) + + return model + + +class Wan2_2_VAE: + def __init__( + self, + z_dim=48, + c_dim=160, + vae_pth=None, + dim_mult=[1, 2, 4, 4], + temperal_downsample=[False, True, True], + dtype=torch.float, + device="cuda", + ): + self.dtype = dtype + self.device = device + + mean = torch.tensor( + [ + -0.2289, + -0.0052, + -0.1323, + -0.2339, + -0.2799, + 0.0174, + 0.1838, + 0.1557, + -0.1382, + 0.0542, + 0.2813, + 0.0891, + 0.1570, + -0.0098, + 0.0375, + -0.1825, + -0.2246, + -0.1207, + -0.0698, + 0.5109, + 0.2665, + -0.2108, + -0.2158, + 0.2502, + -0.2055, + -0.0322, + 0.1109, + 0.1567, + -0.0729, + 0.0899, + -0.2799, + -0.1230, + -0.0313, + -0.1649, + 0.0117, + 0.0723, + -0.2839, + -0.2083, + -0.0520, + 0.3748, + 0.0152, + 0.1957, + 0.1433, + -0.2944, + 0.3573, + -0.0548, + -0.1681, + -0.0667, + ], + dtype=dtype, + device=device, + ) + std = torch.tensor( + [ + 0.4765, + 1.0364, + 0.4514, + 1.1677, + 0.5313, + 0.4990, + 0.4818, + 0.5013, + 0.8158, + 1.0344, + 0.5894, + 1.0901, + 0.6885, + 0.6165, + 0.8454, + 0.4978, + 0.5759, + 0.3523, + 0.7135, + 0.6804, + 0.5833, + 1.4146, + 0.8986, + 0.5659, 
+ 0.7069, + 0.5338, + 0.4889, + 0.4917, + 0.4069, + 0.4999, + 0.6866, + 0.4093, + 0.5709, + 0.6065, + 0.6415, + 0.4944, + 0.5726, + 1.2042, + 0.5458, + 1.6887, + 0.3971, + 1.0600, + 0.3943, + 0.5537, + 0.5444, + 0.4089, + 0.7468, + 0.7744, + ], + dtype=dtype, + device=device, + ) + self.scale = [mean, 1.0 / std] + + # init model + self.model = ( + _video_vae( + pretrained_path=vae_pth, + z_dim=z_dim, + dim=c_dim, + dim_mult=dim_mult, + temperal_downsample=temperal_downsample, + ) + .eval() + .requires_grad_(False) + .to(device) + ) + + def encode(self, videos): + with torch.amp.autocast("cuda", dtype=self.dtype): + return self.model.encode(videos, self.scale) + + def decode(self, zs): + try: + if not isinstance(zs, list): + raise TypeError("zs should be a list") + with amp.autocast(dtype=self.dtype): + return [self.model.decode(u.unsqueeze(0), self.scale).float().clamp_(-1, 1).squeeze(0) for u in zs] + # TODO: maybe can speed up with batch + # return self.model.decode(video_latents, self.scale).float().clamp(-1, 1) + except TypeError as e: + logging.info(e) + return None diff --git a/lightx2v/models/networks/motus/wan_model.py b/lightx2v/models/networks/motus/wan_model.py new file mode 100644 index 000000000..7a63ba5bb --- /dev/null +++ b/lightx2v/models/networks/motus/wan_model.py @@ -0,0 +1,80 @@ +import json +import logging +import os +from typing import Any, Dict, Optional + +import torch +import torch.nn as nn + +from lightx2v.models.networks.motus.wan.model import WanModel +from lightx2v.models.networks.motus.wan.vae2_2 import Wan2_2_VAE + +try: + from safetensors.torch import load_file as safe_load_file +except Exception: + safe_load_file = None + +logger = logging.getLogger(__name__) + + +def _strip_known_prefixes_for_wan(sd: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if not isinstance(sd, dict): + return sd + if not any(key.startswith("dit.") for key in sd.keys()): + return sd + return {(key[4:] if key.startswith("dit.") else key): value for 
key, value in sd.items()} + + +class WanVideoModel(nn.Module): + def __init__(self, model_config: Dict[str, Any], vae_path: str, device: str = "cuda", precision: str = "bfloat16"): + super().__init__() + self.device = torch.device(device) + self.precision = {"float32": torch.float32, "float16": torch.float16, "bfloat16": torch.bfloat16}[precision] + self.wan_model = WanModel(**model_config) + self.wan_model.to(device=self.device, dtype=self.precision) + self.vae = Wan2_2_VAE(vae_pth=vae_path, device=self.device) + + def encode_video(self, video_pixels: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + return self.vae.encode(video_pixels) + + def decode_video(self, video_latents: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + return torch.stack([self.vae.decode([video_latents[i]])[0] for i in range(video_latents.shape[0])], dim=0) + # TODO: maybe can speed up with batch to tensor + # return self.vae.model.decode(video_latents, self.vae.scale).float().clamp(-1, 1) + + @classmethod + def from_config(cls, config_path: str, vae_path: str, device: str = "cuda", precision: str = "bfloat16"): + config_json_path = os.path.join(config_path, "config.json") + with open(config_json_path, "r") as file: + model_config = json.load(file) + return cls(model_config=model_config, vae_path=vae_path, device=device, precision=precision) + + @classmethod + def from_pretrained(cls, checkpoint_path: str, vae_path: str, config_path: Optional[str] = None, device: str = "cuda", precision: str = "bfloat16"): + config_path = config_path or checkpoint_path + config_json_path = os.path.join(config_path, "config.json") + with open(config_json_path, "r") as file: + model_config = json.load(file) + model = cls(model_config=model_config, vae_path=vae_path, device=device, precision=precision) + + if checkpoint_path.endswith(".pt"): + loaded = torch.load(checkpoint_path, map_location="cpu") + state_dict = loaded["model"] if isinstance(loaded, dict) and "model" in loaded else loaded + 
elif checkpoint_path.endswith(".bin") or checkpoint_path.endswith(".safetensors"): + if checkpoint_path.endswith(".safetensors"): + if safe_load_file is None: + raise RuntimeError("safetensors is not installed") + state_dict = safe_load_file(checkpoint_path, device="cpu") + else: + loaded = torch.load(checkpoint_path, map_location="cpu") + state_dict = loaded.get("state_dict", loaded.get("model", loaded)) if isinstance(loaded, dict) else loaded + else: + loaded_model = WanModel.from_pretrained(checkpoint_path) + model.wan_model.load_state_dict(loaded_model.state_dict(), strict=False) + return model + + state_dict = _strip_known_prefixes_for_wan(state_dict) + model.wan_model.load_state_dict(state_dict, strict=False) + return model diff --git a/lightx2v/models/runners/motus/__init__.py b/lightx2v/models/runners/motus/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/runners/motus/motus_runner.py b/lightx2v/models/runners/motus/motus_runner.py new file mode 100644 index 000000000..384623dc0 --- /dev/null +++ b/lightx2v/models/runners/motus/motus_runner.py @@ -0,0 +1,94 @@ +import json +from pathlib import Path + +import numpy as np +import torch +from loguru import logger + +from lightx2v.models.networks.motus.model import MotusModel +from lightx2v.models.runners.base_runner import BaseRunner +from lightx2v.server.metrics import monitor_cli +from lightx2v.utils.profiler import * +from lightx2v.utils.registry_factory import RUNNER_REGISTER +from lightx2v.utils.utils import save_to_video +from lightx2v_platform.base.global_var import AI_DEVICE + + +@RUNNER_REGISTER("motus") +class MotusRunner(BaseRunner): + def __init__(self, config): + super().__init__(config) + self.device = torch.device(AI_DEVICE) + self.adapter = None + + @ProfilingContext4DebugL2("Load models") + def init_modules(self): + self.adapter = MotusModel(self.config, self.device) + + def _load_state_value(self, state_path: str): + state_path = 
str(Path(state_path).expanduser().resolve()) + suffix = Path(state_path).suffix.lower() + if suffix == ".npy": + return np.load(state_path) + if suffix in [".pt", ".pth"]: + value = torch.load(state_path, map_location="cpu") + if isinstance(value, dict): + for key in ["state", "qpos", "joint_state", "initial_state"]: + if key in value: + return value[key] + return value + if suffix == ".json": + with open(state_path, "r") as f: + value = json.load(f) + if isinstance(value, dict): + for key in ["state", "qpos", "joint_state", "initial_state"]: + if key in value: + return value[key] + return value + if suffix in [".txt", ".csv"]: + text = Path(state_path).read_text().strip().replace("\n", ",") + return [float(item) for item in text.split(",") if item.strip()] + raise ValueError(f"Unsupported state file format: {state_path}") + + def _resolve_action_output_path(self): + if self.input_info.save_action_path: + return str(Path(self.input_info.save_action_path).expanduser().resolve()) + return str(Path(self.input_info.save_result_path).expanduser().resolve().with_suffix(".actions.json")) + + def _save_outputs(self, pred_frames: torch.Tensor, pred_actions: torch.Tensor): + video_path = str(Path(self.input_info.save_result_path).expanduser().resolve()) + action_path = self._resolve_action_output_path() + + video = pred_frames.clamp(0, 1).permute(0, 2, 3, 1).contiguous() + save_to_video(video, video_path, fps=float(self.config.get("fps", 4)), method="ffmpeg") + + Path(action_path).parent.mkdir(parents=True, exist_ok=True) + with open(action_path, "w") as f: + json.dump(pred_actions.detach().cpu().float().tolist(), f, ensure_ascii=False, indent=2) + + logger.info(f"Saved Motus video to {video_path}") + logger.info(f"Saved Motus actions to {action_path}") + + @ProfilingContext4DebugL1("RUN pipeline", recorder_mode=GET_RECORDER_MODE(), metrics_func=monitor_cli.lightx2v_worker_request_duration, metrics_labels=["MotusRunner"]) + def run_pipeline(self, input_info): + 
self.input_info = input_info + if self.adapter is None: + self.init_modules() + + if not self.input_info.image_path: + raise ValueError("Motus requires `image_path`.") + if not self.input_info.state_path: + raise ValueError("Motus requires `state_path`.") + if not self.input_info.prompt: + raise ValueError("Motus requires `prompt`.") + if not self.input_info.save_result_path: + raise ValueError("Motus requires `save_result_path`.") + + state_value = self._load_state_value(self.input_info.state_path) + pred_frames, pred_actions = self.adapter.infer( + image_path=self.input_info.image_path, + prompt=self.input_info.prompt, + state_value=state_value, + num_inference_steps=int(self.config.get("num_inference_steps", 10)), + ) + self._save_outputs(pred_frames, pred_actions) diff --git a/lightx2v/models/schedulers/motus/__init__.py b/lightx2v/models/schedulers/motus/__init__.py new file mode 100644 index 000000000..f20aaf60b --- /dev/null +++ b/lightx2v/models/schedulers/motus/__init__.py @@ -0,0 +1,3 @@ +from .scheduler import MotusScheduler + +__all__ = ["MotusScheduler"] diff --git a/lightx2v/models/schedulers/motus/scheduler.py b/lightx2v/models/schedulers/motus/scheduler.py new file mode 100644 index 000000000..faf198ada --- /dev/null +++ b/lightx2v/models/schedulers/motus/scheduler.py @@ -0,0 +1,39 @@ +import torch + +from lightx2v.models.schedulers.scheduler import BaseScheduler + + +class MotusScheduler(BaseScheduler): + def __init__(self, config): + super().__init__(config) + self.video_latents = None + self.action_latents = None + self.timesteps = None + + def prepare(self, seed, condition_frame_latent, action_shape, dtype, device): + batch, channels, _, latent_h, latent_w = condition_frame_latent.shape + total_latent_frames = 1 + self.config["num_video_frames"] // 4 + generator = None if seed is None else torch.Generator(device=device).manual_seed(seed) + + self.video_latents = torch.randn( + (batch, channels, total_latent_frames, latent_h, latent_w), + 
device=device, + dtype=dtype, + generator=generator, + ) + self.video_latents[:, :, 0:1] = condition_frame_latent + self.action_latents = torch.randn(action_shape, device=device, dtype=dtype, generator=generator) + self.timesteps = torch.linspace(1.0, 0.0, self.infer_steps + 1, device=device, dtype=dtype) + self.latents = self.video_latents + + def iter_steps(self): + for step_index in range(self.infer_steps): + t = self.timesteps[step_index] + t_next = self.timesteps[step_index + 1] + yield step_index, t, t_next, t_next - t + + def step(self, video_velocity, action_velocity, dt, condition_frame_latent): + self.video_latents = self.video_latents + video_velocity * dt + self.action_latents = self.action_latents + action_velocity * dt + self.video_latents[:, :, 0:1] = condition_frame_latent + self.latents = self.video_latents diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index ba5e7d64d..65f009614 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -46,6 +46,9 @@ class I2VInputInfo: pose: str = field(default_factory=lambda: None) # Lingbot i2v camera/action conditioning (optional) action_path: str = field(default_factory=str) + # Motus i2v action_expert conditioning (optional) + state_path: str = field(default_factory=str) + save_action_path: str = field(default_factory=str) @dataclass diff --git a/scripts/motus/example_inputs/first_frame.png b/scripts/motus/example_inputs/first_frame.png new file mode 100644 index 000000000..1c3c39c5e Binary files /dev/null and b/scripts/motus/example_inputs/first_frame.png differ diff --git a/scripts/motus/example_inputs/state.npy b/scripts/motus/example_inputs/state.npy new file mode 100644 index 000000000..16842ec92 Binary files /dev/null and b/scripts/motus/example_inputs/state.npy differ diff --git a/scripts/motus/run_motus_i2v.sh b/scripts/motus/run_motus_i2v.sh new file mode 100644 index 000000000..58a68f895 --- /dev/null +++ b/scripts/motus/run_motus_i2v.sh @@ -0,0 +1,22 
@@ +#!/bin/bash + +# set path firstly +lightx2v_path=/path/to/LightX2V +model_path=/path/to/MotusModel + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls motus \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/motus/motus_i2v.json \ +--image_path "/path/to/the/first/frame: example_inputs/frist_frame.png" \ +--state_path "/path/to/the/state/at/the/first/frame: example_inputs/state.npy" \ +--prompt "Example prompt: The whole scene is in a realistic, industrial art style with three views: a fixed rear camera, a movable left arm camera, and a movable right arm camera. The aloha robot is currently performing the following task: Pick the bottle with ridges near base head-up using the right arm" \ +--save_result_path ${lightx2v_path}/save_results/output_motus.mp4 \ +--save_action_path ${lightx2v_path}/save_results/output_motus.actions.json \ +--seed 42