From 8029c3ced752c0b235ee78c96bce34076cf0cc43 Mon Sep 17 00:00:00 2001 From: gushiqiao <975033167@qq.com> Date: Tue, 30 Jun 2026 04:27:29 +0000 Subject: [PATCH] feat(cosmos3): add Cosmos3 Super Omni inference tasks --- ...mos3_super_omni_action_fd_agibotworld.json | 28 + ...omni_action_fd_agibotworld_multichunk.json | 30 + .../cosmos3_super_omni_action_id_av.json | 27 + configs/cosmos3/cosmos3_super_omni_i2av.json | 23 + configs/cosmos3/cosmos3_super_omni_i2v.json | 22 + configs/cosmos3/cosmos3_super_omni_t2av.json | 23 + configs/cosmos3/cosmos3_super_omni_t2v.json | 22 + configs/cosmos3/cosmos3_super_t2v.json | 22 + lightx2v/infer.py | 5 + lightx2v/models/audio_encoders/__init__.py | 0 lightx2v/models/audio_encoders/hf/__init__.py | 0 .../audio_encoders/hf/cosmos3/__init__.py | 0 .../hf/cosmos3/sound_tokenizer.py | 138 +++++ .../networks/cosmos3/infer/module_io.py | 15 + .../networks/cosmos3/infer/post_infer.py | 46 +- .../networks/cosmos3/infer/pre_infer.py | 155 ++++- .../models/networks/cosmos3/infer/utils.py | 9 +- lightx2v/models/networks/cosmos3/model.py | 57 +- .../networks/cosmos3/weights/post_weights.py | 12 + .../networks/cosmos3/weights/pre_weights.py | 34 +- .../models/runners/cosmos3/cosmos3_runner.py | 577 ++++++++++++++++-- .../models/schedulers/cosmos3/scheduler.py | 68 +++ .../models/video_encoders/hf/cosmos3/vae.py | 5 +- lightx2v/pipeline.py | 4 + lightx2v/utils/input_info.py | 10 + lightx2v/utils/set_config.py | 7 +- ...osmos3_super_omni_action_fd_agibotworld.sh | 23 + ...r_omni_action_fd_agibotworld_multichunk.sh | 23 + .../cosmos3_super_omni_action_id_av.sh | 22 + scripts/cosmos3/cosmos3_super_omni_i2av.sh | 24 + scripts/cosmos3/cosmos3_super_omni_i2v.sh | 24 + scripts/cosmos3/cosmos3_super_omni_t2av.sh | 22 + scripts/cosmos3/cosmos3_super_omni_t2v.sh | 22 + 33 files changed, 1427 insertions(+), 72 deletions(-) create mode 100644 configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json create mode 100644 configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.json create mode 100644 configs/cosmos3/cosmos3_super_omni_action_id_av.json create mode 100644 configs/cosmos3/cosmos3_super_omni_i2av.json create mode 100644 configs/cosmos3/cosmos3_super_omni_i2v.json create mode 100644 configs/cosmos3/cosmos3_super_omni_t2av.json create mode 100644 configs/cosmos3/cosmos3_super_omni_t2v.json create mode 100644 configs/cosmos3/cosmos3_super_t2v.json create mode 100644 lightx2v/models/audio_encoders/__init__.py create mode 100644 lightx2v/models/audio_encoders/hf/__init__.py create mode 100644 lightx2v/models/audio_encoders/hf/cosmos3/__init__.py create mode 100644 lightx2v/models/audio_encoders/hf/cosmos3/sound_tokenizer.py create mode 100644 scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_action_id_av.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_i2av.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_i2v.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_t2av.sh create mode 100644 scripts/cosmos3/cosmos3_super_omni_t2v.sh diff --git a/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json b/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json new file mode 100644 index 000000000..63724a4a3 --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json @@ -0,0 +1,28 @@ +{ + "infer_steps": 30, + "sample_guide_scale": 1.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 640, + "target_video_length": 17, + "target_fps": 10.0, + "enable_cfg": true, + "action_mode": "forward_dynamics", + "domain_name": "agibotworld", + "view_point": "concat_view", + "action_chunk_size": 16, + "action_chunk_index": 0, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "use_system_prompt": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.json b/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.json new file mode 100644 index 000000000..b62d93fd5 --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.json @@ -0,0 +1,30 @@ +{ + "infer_steps": 30, + "sample_guide_scale": 1.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 640, + "target_video_length": 17, + "target_fps": 10.0, + "enable_cfg": true, + "action_mode": "forward_dynamics", + "domain_name": "agibotworld", + "view_point": "concat_view", + "action_chunk_size": 16, + "action_chunk_index": 0, + "action_multichunk": true, + "action_num_chunks": 4, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "use_system_prompt": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_action_id_av.json b/configs/cosmos3/cosmos3_super_omni_action_id_av.json new file mode 100644 index 000000000..05d05ecdf --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_action_id_av.json @@ -0,0 +1,27 @@ +{ + "infer_steps": 30, + "sample_guide_scale": 1.0, + "sample_shift": 10.0, + "target_height": 480, + "target_width": 832, + "target_video_length": 61, + "target_fps": 10.0, + "enable_cfg": true, + "action_mode": "inverse_dynamics", + "domain_name": "av", + "view_point": "ego_view", + "action_chunk_size": 60, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "use_system_prompt": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_i2av.json b/configs/cosmos3/cosmos3_super_omni_i2av.json new file mode 100644 index 000000000..48602096e --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_i2av.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 35, + "sample_guide_scale": 6.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 1280, + "target_video_length": 189, + "target_fps": 24.0, + "enable_cfg": true, + "enable_sound": true, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_i2v.json b/configs/cosmos3/cosmos3_super_omni_i2v.json new file mode 100644 index 000000000..3072500ad --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_i2v.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 35, + "sample_guide_scale": 6.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 1280, + "target_video_length": 189, + "target_fps": 24.0, + "enable_cfg": true, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_t2av.json b/configs/cosmos3/cosmos3_super_omni_t2av.json new file mode 100644 index 000000000..48602096e --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_t2av.json @@ -0,0 +1,23 @@ +{ + "infer_steps": 35, + "sample_guide_scale": 6.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 1280, + "target_video_length": 189, + "target_fps": 24.0, + "enable_cfg": true, + "enable_sound": true, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_omni_t2v.json b/configs/cosmos3/cosmos3_super_omni_t2v.json new file mode 100644 index 000000000..3072500ad --- /dev/null +++ b/configs/cosmos3/cosmos3_super_omni_t2v.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 35, + "sample_guide_scale": 6.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 1280, + "target_video_length": 189, + "target_fps": 24.0, + "enable_cfg": true, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/configs/cosmos3/cosmos3_super_t2v.json b/configs/cosmos3/cosmos3_super_t2v.json new file mode 100644 index 000000000..3072500ad --- /dev/null +++ b/configs/cosmos3/cosmos3_super_t2v.json @@ -0,0 +1,22 @@ +{ + "infer_steps": 35, + "sample_guide_scale": 6.0, + "sample_shift": 10.0, + "target_height": 720, + "target_width": 1280, + "target_video_length": 189, + "target_fps": 24.0, + "enable_cfg": true, + "feature_caching": "NoCaching", + "rms_norm_type": "one-pass", + "attn_rms_norm_type": "one-pass", + "rope_type": "triton", + "self_attn_type": "flash_attn3", + "causal_self_attn_type": "flash_attn3", + "add_resolution_template": false, + "add_duration_template": false, + "cosmos3_meta_init": true, + "vae_cpu_offload": false, + "cpu_offload": false, + "offload_granularity": "block" +} diff --git a/lightx2v/infer.py b/lightx2v/infer.py index caa12bce8..6f3b11e7d 100755 --- a/lightx2v/infer.py +++ b/lightx2v/infer.py @@ -199,6 +199,11 @@ def main(): default=None, help="Directory path for lingbot camera/action control files (poses.npy, intrinsics.npy, optional action.npy).", ) + parser.add_argument("--action_mode", type=str, default=None, choices=["forward_dynamics", "inverse_dynamics", "policy"], help="Cosmos3 action mode.") + parser.add_argument("--domain_name", type=str, default=None, help="Cosmos3 action embodiment domain name.") + parser.add_argument("--view_point", type=str, default=None, help="Cosmos3 action viewpoint label.") + parser.add_argument("--action_chunk_size", type=int, default=None, help="Cosmos3 action chunk size.") + parser.add_argument("--action_chunk_index", type=int, default=None, help="Cosmos3 action chunk index when action_path contains action_chunks.") parser.add_argument( "--action_ckpt", type=str, diff --git a/lightx2v/models/audio_encoders/__init__.py b/lightx2v/models/audio_encoders/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/audio_encoders/hf/__init__.py b/lightx2v/models/audio_encoders/hf/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/audio_encoders/hf/cosmos3/__init__.py b/lightx2v/models/audio_encoders/hf/cosmos3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lightx2v/models/audio_encoders/hf/cosmos3/sound_tokenizer.py b/lightx2v/models/audio_encoders/hf/cosmos3/sound_tokenizer.py new file mode 100644 index 000000000..d43bf8a79 --- /dev/null +++ b/lightx2v/models/audio_encoders/hf/cosmos3/sound_tokenizer.py @@ -0,0 +1,138 @@ +import json +import math +import os + +import torch +import torch.nn as nn +from safetensors import safe_open +from torch.nn.utils import weight_norm + + +class Snake1d(nn.Module): + def __init__(self, hidden_dim, logscale=True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.logscale = logscale + + def forward(self, hidden_states): + shape = hidden_states.shape + alpha = self.alpha if not self.logscale else torch.exp(self.alpha) + beta = self.beta if not self.logscale else torch.exp(self.beta) + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + return hidden_states.reshape(shape) + + +class Cosmos3AudioResidualUnit(nn.Module): + def __init__(self, dimension, dilation=1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state): + output_tensor = self.conv1(self.snake1(hidden_state)) + output_tensor = self.conv2(self.snake2(output_tensor)) + padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + return hidden_state + output_tensor + + +class Cosmos3AudioDecoderBlock(nn.Module): + def __init__(self, input_dim, output_dim, stride=1, output_padding=0): + super().__init__() + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + output_padding=output_padding, + ) + ) + self.res_unit1 = Cosmos3AudioResidualUnit(output_dim, dilation=1) + self.res_unit2 = Cosmos3AudioResidualUnit(output_dim, dilation=3) + self.res_unit3 = Cosmos3AudioResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state): + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.res_unit3(hidden_state) + return hidden_state + + +class Cosmos3AudioDecoder(nn.Module): + def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples): + super().__init__() + channel_multiples = [1] + list(channel_multiples) + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + block = [] + for stride_index, stride in enumerate(upsampling_ratios): + block.append( + Cosmos3AudioDecoderBlock( + input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index], + output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1], + stride=stride, + output_padding=stride % 2, + ) + ) + self.block = nn.ModuleList(block) + self.snake1 = Snake1d(channels) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + for layer in self.block: + hidden_state = layer(hidden_state) + hidden_state = self.snake1(hidden_state) + return self.conv2(hidden_state) + + +class Cosmos3SoundTokenizer(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.sampling_rate = int(config.get("sampling_rate", 48000)) + self.hop_size = int(config.get("hop_size") or math.prod(config.get("dec_strides", [2, 4, 5, 6, 8]))) + self.decoder = Cosmos3AudioDecoder( + channels=int(config.get("dec_dim", 320)), + input_channels=int(config.get("vocoder_input_dim", 64)), + audio_channels=int(config.get("dec_out_channels", 2)), + upsampling_ratios=list(reversed(config.get("dec_strides", [2, 4, 5, 6, 8]))), + channel_multiples=list(config.get("dec_c_mults", [1, 2, 4, 8, 16])), + ) + + @classmethod + def from_pretrained(cls, model_path, device, dtype): + tokenizer_path = os.path.join(model_path, "sound_tokenizer") + with open(os.path.join(tokenizer_path, "config.json"), "r") as f: + config = json.load(f) + model = cls(config) + weight_path = os.path.join(tokenizer_path, "diffusion_pytorch_model.safetensors") + state_dict = {} + with safe_open(weight_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("decoder."): + state_dict[key] = f.get_tensor(key) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + unexpected = [key for key in unexpected if key.startswith("decoder.")] + missing = [key for key in missing if key.startswith("decoder.")] + if missing or unexpected: + raise RuntimeError(f"Failed to load Cosmos3 sound decoder, missing={missing}, unexpected={unexpected}") + return model.to(device=device, dtype=dtype).eval() + + @torch.no_grad() + def decode(self, latents): + squeeze = latents.ndim == 2 + if squeeze: + latents = latents.unsqueeze(0) + audio = self.decoder(latents).clamp(-1.0, 1.0) + return audio.squeeze(0) if squeeze else audio diff --git a/lightx2v/models/networks/cosmos3/infer/module_io.py b/lightx2v/models/networks/cosmos3/infer/module_io.py index 0fd21da0b..89ec3ece0 100644 --- a/lightx2v/models/networks/cosmos3/infer/module_io.py +++ b/lightx2v/models/networks/cosmos3/infer/module_io.py @@ -12,6 +12,14 @@ class Cosmos3PreInferModuleOutput: vision_token_shapes: list vision_noisy_frame_indexes: list original_latent_shapes: list + sound_mse_loss_indexes: torch.Tensor | None = None + sound_token_shapes: list | None = None + sound_noisy_frame_indexes: list | None = None + action_mse_loss_indexes: torch.Tensor | None = None + action_token_shapes: list | None = None + action_noisy_frame_indexes: list | None = None + action_domain_ids: torch.Tensor | None = None + raw_action_dim: int | None = None seq_p_gen_len: int | None = None seq_p_gen_padding_size: int = 0 seq_p_local_gen_len: int | None = None @@ -21,3 +29,10 @@ class Cosmos3PreInferModuleOutput: class Cosmos3TransformerInferModuleOutput: und_seq: torch.Tensor gen_seq: torch.Tensor + + +@dataclass +class Cosmos3PostInferModuleOutput: + vision: torch.Tensor + sound: torch.Tensor | None = None + action: torch.Tensor | None = None diff --git a/lightx2v/models/networks/cosmos3/infer/post_infer.py b/lightx2v/models/networks/cosmos3/infer/post_infer.py index bcb4c142a..3abc695b0 100644 --- a/lightx2v/models/networks/cosmos3/infer/post_infer.py +++ b/lightx2v/models/networks/cosmos3/infer/post_infer.py @@ -1,5 +1,7 @@ import torch +from lightx2v.models.networks.cosmos3.infer.module_io import Cosmos3PostInferModuleOutput + class Cosmos3PostInfer: def __init__(self, config): @@ -46,6 +48,27 @@ def _unpatchify_and_unpack_latents( unpatchified_latents.append(output_tensor.unsqueeze(0)) return unpatchified_latents + @staticmethod + def _unpack_sound(preds_sound_packed): + return preds_sound_packed.transpose(0, 1).contiguous() + + @staticmethod + def _unpack_action(preds_action_packed, token_shapes_action, noisy_frame_indexes_action, raw_action_dim): + action_len = token_shapes_action[0][0] + noisy_frame_indexes = noisy_frame_indexes_action[0] + if len(noisy_frame_indexes) == 0: + return None + action_dim = preds_action_packed.shape[-1] if preds_action_packed.numel() > 0 else int(raw_action_dim or 0) + if action_dim == 0: + return None + output_tensor = torch.zeros( + (action_len, action_dim), + device=preds_action_packed.device, + dtype=preds_action_packed.dtype, + ) + output_tensor[noisy_frame_indexes] = preds_action_packed + return output_tensor + def infer(self, weights, transformer_out, pre_infer_out): last_hidden_state = torch.cat( [ @@ -61,4 +84,25 @@ def infer(self, weights, transformer_out, pre_infer_out): noisy_frame_indexes_vision=pre_infer_out.vision_noisy_frame_indexes, original_latent_shapes=pre_infer_out.original_latent_shapes, ) - return preds_vision[0] + preds_sound = None + if pre_infer_out.sound_mse_loss_indexes is not None: + preds_sound_packed = weights.audio_proj_out.apply(last_hidden_state[pre_infer_out.sound_mse_loss_indexes]) + preds_sound = self._unpack_sound(preds_sound_packed) + + preds_action = None + if pre_infer_out.action_mse_loss_indexes is not None and pre_infer_out.action_domain_ids is not None: + noisy_action_indexes = pre_infer_out.action_noisy_frame_indexes[0] + if len(noisy_action_indexes) > 0: + action_domain_ids = pre_infer_out.action_domain_ids[noisy_action_indexes] + preds_action_packed = weights.action_proj_out.apply( + last_hidden_state[pre_infer_out.action_mse_loss_indexes], + action_domain_ids, + ) + preds_action = self._unpack_action( + preds_action_packed, + pre_infer_out.action_token_shapes, + pre_infer_out.action_noisy_frame_indexes, + pre_infer_out.raw_action_dim, + ) + + return Cosmos3PostInferModuleOutput(vision=preds_vision[0], sound=preds_sound, action=preds_action) diff --git a/lightx2v/models/networks/cosmos3/infer/pre_infer.py b/lightx2v/models/networks/cosmos3/infer/pre_infer.py index 49b6eb852..e2ab5841a 100644 --- a/lightx2v/models/networks/cosmos3/infer/pre_infer.py +++ b/lightx2v/models/networks/cosmos3/infer/pre_infer.py @@ -45,17 +45,30 @@ def _patchify_and_pack_latents(self, latents): return torch.cat(packed_latent, dim=0), original_latent_shapes def _apply_timestep_embeds_to_noisy_tokens(self, packed_tokens, packed_timestep_embeds, noisy_frame_indexes, token_shapes): + if packed_timestep_embeds.numel() == 0: + return packed_tokens start_noisy_index = 0 flattened_noisy_frame_indexes = [] for noisy_indexes_i, token_shape_i in zip(noisy_frame_indexes, token_shapes): spatial_numel_i = math.prod(token_shape_i[1:]) - spatial_indexes_i = torch.arange(spatial_numel_i, device=packed_tokens.device) - frame_offsets = (noisy_indexes_i * spatial_numel_i).unsqueeze(-1) + spatial_indexes_i + start_noisy_index - flattened_noisy_frame_indexes.append(frame_offsets.flatten()) + if len(noisy_indexes_i) > 0: + spatial_indexes_i = torch.arange(spatial_numel_i, device=packed_tokens.device) + frame_offsets = (noisy_indexes_i * spatial_numel_i).unsqueeze(-1) + spatial_indexes_i + start_noisy_index + flattened_noisy_frame_indexes.append(frame_offsets.flatten()) start_noisy_index += token_shape_i[0] * spatial_numel_i + if not flattened_noisy_frame_indexes: + return packed_tokens flattened = torch.cat(flattened_noisy_frame_indexes, dim=0).unsqueeze(-1).expand(-1, packed_tokens.shape[1]) return packed_tokens.scatter_add(dim=0, index=flattened, src=packed_timestep_embeds) + def _embed_timestep(self, weights, timestep, length, device, dtype): + if length == 0: + return torch.empty((0, self.hidden_size), device=device, dtype=dtype) + timestep = timestep.to(device=device, dtype=torch.float32) if isinstance(timestep, torch.Tensor) else torch.tensor(float(timestep), device=device, dtype=torch.float32) + timesteps = timestep.expand(length) * self.timestep_scale + timestep_proj = get_timestep_embedding(timesteps, 256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1).to(device=device, dtype=dtype) + return weights.time_embedder_linear_2.apply(F.silu(weights.time_embedder_linear_1.apply(timestep_proj))).to(dtype) + def _prepare_text_segment(self, input_ids, device): und_len = len(input_ids) text_mrope_ids, next_mrope_offset = get_3d_mrope_ids_text_tokens( @@ -79,11 +92,7 @@ def _prepare_vision_segment(self, latents, text_segment, device, condition_frame num_vision_tokens = latent_t * patch_h * patch_w condition_frame_indexes = [] if condition_frame_indexes is None else condition_frame_indexes condition_frame_set = {int(idx) for idx in condition_frame_indexes if 0 <= int(idx) < latent_t} - noisy_frame_indexes = torch.tensor( - [idx for idx in range(latent_t) if idx not in condition_frame_set], - device=device, - dtype=torch.long, - ) + noisy_frame_indexes = torch.tensor([idx for idx in range(latent_t) if idx not in condition_frame_set], device=device, dtype=torch.long) frame_token_stride = patch_h * patch_w curr = text_segment["und_len"] mse_loss_indexes = [] @@ -112,11 +121,83 @@ def _prepare_vision_segment(self, latents, text_segment, device, condition_frame "num_noisy_vision_tokens": len(noisy_frame_indexes) * frame_token_stride, } - def infer(self, weights, input_ids, latents, timestep, condition_frame_indexes=None): + def _prepare_sound_segment(self, sound_latents, text_segment, vision_segment, device): + sound_len = int(sound_latents.shape[1]) + curr = text_segment["und_len"] + vision_segment["num_vision_tokens"] + effective_fps = self.config.get("sound_latent_fps", 25.0) if self.enable_fps_modulation else None + sound_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=sound_len, + grid_h=1, + grid_w=1, + temporal_offset=text_segment["vision_start_temporal_offset"], + reset_spatial_indices=self.reset_spatial_ids, + fps=effective_fps, + base_fps=self.base_fps, + temporal_compression_factor=1, + ) + indexes = torch.arange(sound_len, device=device, dtype=torch.long) + return { + "sound_token_shapes": [(sound_len, 1, 1)], + "sound_sequence_indexes": torch.arange(curr, curr + sound_len, dtype=torch.long, device=device), + "sound_mse_loss_indexes": torch.arange(curr, curr + sound_len, dtype=torch.long, device=device), + "sound_noisy_frame_indexes": [indexes], + "sound_mrope_ids": sound_mrope_ids.to(device), + "sound_len": sound_len, + } + + def _prepare_action_segment(self, action_latents, text_segment, vision_segment, sound_segment, device, condition_frame_indexes=None): + action_len = int(action_latents.shape[0]) + condition_frame_indexes = [] if condition_frame_indexes is None else condition_frame_indexes + cond_set = {int(idx) for idx in condition_frame_indexes if 0 <= int(idx) < action_len} + noisy_frame_indexes = torch.tensor([idx for idx in range(action_len) if idx not in cond_set], device=device, dtype=torch.long) + curr = text_segment["und_len"] + vision_segment["num_vision_tokens"] + sound_segment.get("sound_len", 0) + effective_fps = self.config.get("target_fps", 24.0) if self.enable_fps_modulation else None + action_mrope_ids, _ = get_3d_mrope_ids_vae_tokens( + grid_t=action_len, + grid_h=1, + grid_w=1, + temporal_offset=text_segment["vision_start_temporal_offset"], + reset_spatial_indices=self.reset_spatial_ids, + fps=effective_fps, + base_fps=self.base_fps, + temporal_compression_factor=1, + base_temporal_compression_factor=self.config.get("vae_scale_factor_temporal", 4), + start_frame_offset=1, + ) + sequence_indexes = torch.arange(curr, curr + action_len, dtype=torch.long, device=device) + return { + "action_token_shapes": [(action_len, 1, 1)], + "action_sequence_indexes": sequence_indexes, + "action_mse_loss_indexes": sequence_indexes[noisy_frame_indexes], + "action_noisy_frame_indexes": [noisy_frame_indexes], + "action_mrope_ids": action_mrope_ids.to(device), + "action_len": action_len, + "num_noisy_action_tokens": len(noisy_frame_indexes), + } + + def infer( + self, + weights, + input_ids, + latents, + timestep, + condition_frame_indexes=None, + sound_latents=None, + action_latents=None, + action_domain_id=None, + action_condition_frame_indexes=None, + raw_action_dim=None, + ): device = latents.device text_segment = self._prepare_text_segment(input_ids, device) vision_segment = self._prepare_vision_segment(latents, text_segment, device, condition_frame_indexes=condition_frame_indexes) - sequence_length = text_segment["und_len"] + vision_segment["num_vision_tokens"] + sound_segment = {} + action_segment = {} + if sound_latents is not None: + sound_segment = self._prepare_sound_segment(sound_latents, text_segment, vision_segment, device) + if action_latents is not None: + action_segment = self._prepare_action_segment(action_latents, text_segment, vision_segment, sound_segment, device, condition_frame_indexes=action_condition_frame_indexes) + sequence_length = text_segment["und_len"] + vision_segment["num_vision_tokens"] + sound_segment.get("sound_len", 0) + action_segment.get("action_len", 0) packed_text_embedding = weights.embed_tokens.apply(text_segment["input_ids"]) target_dtype = packed_text_embedding.dtype @@ -125,16 +206,7 @@ def infer(self, weights, input_ids, latents, timestep, condition_frame_indexes=N packed_tokens_vision, original_latent_shapes = self._patchify_and_pack_latents(latents.to(dtype=target_dtype)) packed_tokens_vision = weights.proj_in.apply(packed_tokens_vision) - timestep = timestep.to(device=device, dtype=torch.float32) if isinstance(timestep, torch.Tensor) else torch.tensor(float(timestep), device=device, dtype=torch.float32) - vision_timesteps = timestep.expand(vision_segment["num_noisy_vision_tokens"]) * self.timestep_scale - timestep_proj = get_timestep_embedding( - vision_timesteps, - 256, - flip_sin_to_cos=True, - downscale_freq_shift=0, - scale=1, - ).to(device=packed_tokens_vision.device, dtype=packed_tokens_vision.dtype) - packed_timestep_embeds_vision = weights.time_embedder_linear_2.apply(F.silu(weights.time_embedder_linear_1.apply(timestep_proj))).to(target_dtype) + packed_timestep_embeds_vision = self._embed_timestep(weights, timestep, vision_segment["num_noisy_vision_tokens"], packed_tokens_vision.device, packed_tokens_vision.dtype) packed_tokens_vision = self._apply_timestep_embeds_to_noisy_tokens( packed_tokens_vision, packed_timestep_embeds_vision, @@ -143,7 +215,40 @@ def infer(self, weights, input_ids, latents, timestep, condition_frame_indexes=N ) hidden_states[vision_segment["vision_sequence_indexes"]] = packed_tokens_vision - position_ids = torch.cat([text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]], dim=1) + if sound_latents is not None: + packed_tokens_sound = sound_latents[:, : sound_segment["sound_len"]].permute(1, 0).to(dtype=target_dtype) + packed_tokens_sound = weights.audio_proj_in.apply(packed_tokens_sound) + weights.audio_modality_embed.tensor.to(device=device, dtype=target_dtype) + packed_timestep_embeds_sound = self._embed_timestep(weights, timestep, sound_segment["sound_len"], packed_tokens_sound.device, packed_tokens_sound.dtype) + packed_tokens_sound = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens_sound, + packed_timestep_embeds_sound, + sound_segment["sound_noisy_frame_indexes"], + sound_segment["sound_token_shapes"], + ) + hidden_states[sound_segment["sound_sequence_indexes"]] = packed_tokens_sound + + action_domain_ids = None + if action_latents is not None: + action_domain_id = torch.as_tensor(action_domain_id, device=device, dtype=torch.long).reshape(1) + action_domain_ids = action_domain_id.expand(action_segment["action_len"]) + packed_tokens_action = action_latents[: action_segment["action_len"]].to(dtype=target_dtype) + packed_tokens_action = weights.action_proj_in.apply(packed_tokens_action, action_domain_ids) + packed_tokens_action = packed_tokens_action + weights.action_modality_embed.tensor.to(device=device, dtype=target_dtype) + packed_timestep_embeds_action = self._embed_timestep(weights, timestep, action_segment["num_noisy_action_tokens"], packed_tokens_action.device, packed_tokens_action.dtype) + packed_tokens_action = self._apply_timestep_embeds_to_noisy_tokens( + packed_tokens_action, + packed_timestep_embeds_action, + action_segment["action_noisy_frame_indexes"], + action_segment["action_token_shapes"], + ) + hidden_states[action_segment["action_sequence_indexes"]] = packed_tokens_action + + mrope_segments = [text_segment["text_mrope_ids"], vision_segment["vision_mrope_ids"]] + if sound_segment: + mrope_segments.append(sound_segment["sound_mrope_ids"]) + if action_segment: + mrope_segments.append(action_segment["action_mrope_ids"]) + position_ids = torch.cat(mrope_segments, dim=1) return Cosmos3PreInferModuleOutput( hidden_states=hidden_states, und_len=text_segment["und_len"], @@ -152,4 +257,12 @@ def infer(self, weights, input_ids, latents, timestep, condition_frame_indexes=N vision_token_shapes=vision_segment["vision_token_shapes"], vision_noisy_frame_indexes=vision_segment["vision_noisy_frame_indexes"], original_latent_shapes=original_latent_shapes, + sound_mse_loss_indexes=sound_segment.get("sound_mse_loss_indexes"), + sound_token_shapes=sound_segment.get("sound_token_shapes"), + sound_noisy_frame_indexes=sound_segment.get("sound_noisy_frame_indexes"), + action_mse_loss_indexes=action_segment.get("action_mse_loss_indexes"), + action_token_shapes=action_segment.get("action_token_shapes"), + action_noisy_frame_indexes=action_segment.get("action_noisy_frame_indexes"), + action_domain_ids=action_domain_ids, + raw_action_dim=raw_action_dim, ) diff --git a/lightx2v/models/networks/cosmos3/infer/utils.py b/lightx2v/models/networks/cosmos3/infer/utils.py index 2b14ab3f4..0e2e8b923 100644 --- a/lightx2v/models/networks/cosmos3/infer/utils.py +++ b/lightx2v/models/networks/cosmos3/infer/utils.py @@ -155,16 +155,19 @@ def get_3d_mrope_ids_vae_tokens( fps: float | None = None, base_fps: float = 24.0, temporal_compression_factor: int = 4, + base_temporal_compression_factor: int | None = None, + start_frame_offset: int = 0, ): fps_modulation_enabled = fps is not None and grid_t > 1 + effective_base_tcf = base_temporal_compression_factor if base_temporal_compression_factor is not None else temporal_compression_factor if fps_modulation_enabled: tps = fps / temporal_compression_factor - base_tps = base_fps / temporal_compression_factor + base_tps = base_fps / effective_base_tcf frame_indices = torch.arange(grid_t, dtype=torch.float32) - t_index = (frame_indices / tps * base_tps + temporal_offset).view(-1, 1) + t_index = ((frame_indices + start_frame_offset) / tps * base_tps + temporal_offset).view(-1, 1) t_index = t_index.expand(-1, grid_h * grid_w).flatten() else: - t_index = torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + int(temporal_offset) + t_index = torch.arange(grid_t, dtype=torch.long).view(-1, 1).expand(-1, grid_h * grid_w).flatten() + int(temporal_offset) + start_frame_offset h_index = torch.arange(grid_h, dtype=torch.long).view(1, -1, 1).expand(grid_t, -1, grid_w).flatten() w_index = torch.arange(grid_w, dtype=torch.long).view(1, 1, -1).expand(grid_t, grid_h, -1).flatten() diff --git a/lightx2v/models/networks/cosmos3/model.py b/lightx2v/models/networks/cosmos3/model.py index 8abea1be4..3dfa34bcf 100644 --- a/lightx2v/models/networks/cosmos3/model.py +++ b/lightx2v/models/networks/cosmos3/model.py @@ -3,7 +3,7 @@ from torch.nn import functional as F from lightx2v.models.networks.base_model import BaseTransformerModel -from lightx2v.models.networks.cosmos3.infer.module_io import Cosmos3TransformerInferModuleOutput +from lightx2v.models.networks.cosmos3.infer.module_io import Cosmos3PostInferModuleOutput, Cosmos3TransformerInferModuleOutput from lightx2v.models.networks.cosmos3.infer.offload.transformer_infer import Cosmos3OffloadTransformerInfer from lightx2v.models.networks.cosmos3.infer.post_infer import Cosmos3PostInfer from lightx2v.models.networks.cosmos3.infer.pre_infer import Cosmos3PreInfer @@ -81,6 +81,44 @@ def _seq_parallel_post_process(self, transformer_out, pre_infer_out): gen_seq = torch.cat(gathered_gen_seq, dim=0)[: pre_infer_out.seq_p_gen_len] return Cosmos3TransformerInferModuleOutput(und_seq=transformer_out.und_seq, gen_seq=gen_seq) + def _combine_cfg_output(self, cond, uncond): + guide = self.scheduler.sample_guide_scale + return Cosmos3PostInferModuleOutput( + vision=uncond.vision + guide * (cond.vision - uncond.vision), + sound=None if cond.sound is None else uncond.sound + guide * (cond.sound - uncond.sound), + action=None if cond.action is None else uncond.action + guide * (cond.action - uncond.action), + ) + + @staticmethod + def _detach_cfg_output(output): + return Cosmos3PostInferModuleOutput( + vision=output.vision, + sound=output.sound, + action=output.action, + ) + + def _set_scheduler_noise_pred(self, output): + self.scheduler.noise_pred = output.vision + self.scheduler.noise_pred_sound = output.sound + self.scheduler.noise_pred_action = output.action + + @staticmethod + def _gather_optional_tensor(tensor, group): + if tensor is None: + return [None, None] + gathered = [torch.zeros_like(tensor) for _ in range(2)] + dist.all_gather(gathered, tensor, group=group) + return gathered + + def _gather_cfg_parallel_output(self, output, group): + vision_list = [torch.zeros_like(output.vision) for _ in range(2)] + dist.all_gather(vision_list, output.vision, group=group) + sound_list = self._gather_optional_tensor(output.sound, group) + action_list = self._gather_optional_tensor(output.action, group) + cond = Cosmos3PostInferModuleOutput(vision=vision_list[0], sound=sound_list[0], action=action_list[0]) + uncond = Cosmos3PostInferModuleOutput(vision=vision_list[1], sound=sound_list[1], action=action_list[1]) + return cond, uncond + @torch.no_grad() def _infer_cond_uncond(self, input_ids): latents = self.scheduler.latents @@ -91,6 +129,11 @@ def _infer_cond_uncond(self, input_ids): latents=latents, timestep=timestep, condition_frame_indexes=getattr(self.scheduler, "vision_condition_frame_indexes", None), + sound_latents=getattr(self.scheduler, "sound_latents", None), + action_latents=getattr(self.scheduler, "action_latents", None), + action_domain_id=getattr(self.scheduler, "action_domain_id", None), + action_condition_frame_indexes=getattr(self.scheduler, "action_condition_frame_indexes", None), + raw_action_dim=getattr(self.scheduler, "raw_action_dim", None), ) if self.config["seq_parallel"]: pre_infer_out = self._seq_parallel_pre_process(pre_infer_out) @@ -118,18 +161,16 @@ def infer(self, inputs): assert dist.get_world_size(cfg_p_group) == 2, "cfg_p_world_size must be equal to 2" cfg_p_rank = dist.get_rank(cfg_p_group) input_ids = text_encoder_output["cond_input_ids"] if cfg_p_rank == 0 else text_encoder_output["uncond_input_ids"] - noise_pred = self._infer_cond_uncond(input_ids) - noise_pred_list = [torch.zeros_like(noise_pred) for _ in range(2)] - dist.all_gather(noise_pred_list, noise_pred, group=cfg_p_group) - cond, uncond = noise_pred_list[0], noise_pred_list[1] - self.scheduler.noise_pred = uncond + self.scheduler.sample_guide_scale * (cond - uncond) + output = self._infer_cond_uncond(input_ids) + cond, uncond = self._gather_cfg_parallel_output(output, cfg_p_group) + self._set_scheduler_noise_pred(self._combine_cfg_output(cond, uncond)) elif do_cfg: cond = self._infer_cond_uncond(text_encoder_output["cond_input_ids"]) uncond = self._infer_cond_uncond(text_encoder_output["uncond_input_ids"]) - self.scheduler.noise_pred = uncond + self.scheduler.sample_guide_scale * (cond - uncond) + self._set_scheduler_noise_pred(self._combine_cfg_output(cond, uncond)) else: cond = self._infer_cond_uncond(text_encoder_output["cond_input_ids"]) - self.scheduler.noise_pred = cond + self._set_scheduler_noise_pred(self._detach_cfg_output(cond)) if self.cpu_offload: self.pre_weight.to_cpu() diff --git a/lightx2v/models/networks/cosmos3/weights/post_weights.py b/lightx2v/models/networks/cosmos3/weights/post_weights.py index 43a5857a4..3f886ae3d 100644 --- a/lightx2v/models/networks/cosmos3/weights/post_weights.py +++ b/lightx2v/models/networks/cosmos3/weights/post_weights.py @@ -1,4 +1,5 @@ from lightx2v.common.modules.weight_module import WeightModule +from lightx2v.models.networks.cosmos3.weights.pre_weights import Cosmos3DomainAwareLinearWeights from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER @@ -11,3 +12,14 @@ def __init__(self, config): self.add_module("norm", RMS_WEIGHT_REGISTER[rms_norm_type]("norm.weight", eps=eps)) self.add_module("norm_moe_gen", RMS_WEIGHT_REGISTER[rms_norm_type]("norm_moe_gen.weight", eps=eps)) self.add_module("proj_out", MM_WEIGHT_REGISTER["Default"]("proj_out.weight", "proj_out.bias")) + if config.get("sound_gen", False): + self.add_module("audio_proj_out", MM_WEIGHT_REGISTER["Default"]("audio_proj_out.weight", "audio_proj_out.bias")) + if config.get("action_gen", False): + self.add_module( + "action_proj_out", + Cosmos3DomainAwareLinearWeights( + "action_proj_out", + input_size=config["hidden_size"], + output_size=config.get("action_dim", config.get("max_action_dim", 64)), + ), + ) diff --git a/lightx2v/models/networks/cosmos3/weights/pre_weights.py b/lightx2v/models/networks/cosmos3/weights/pre_weights.py index 68d7f632b..613f61f96 100644 --- a/lightx2v/models/networks/cosmos3/weights/pre_weights.py +++ b/lightx2v/models/networks/cosmos3/weights/pre_weights.py @@ -1,5 +1,24 @@ +import torch + from lightx2v.common.modules.weight_module import WeightModule -from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER, MM_WEIGHT_REGISTER +from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER, MM_WEIGHT_REGISTER, TENSOR_REGISTER + + +class Cosmos3DomainAwareLinearWeights(WeightModule): + def __init__(self, prefix, input_size, output_size): + super().__init__() + self.input_size = int(input_size) + self.output_size = int(output_size) + self.register_parameter("fc", TENSOR_REGISTER["Default"](f"{prefix}.fc.weight")) + self.register_parameter("bias", TENSOR_REGISTER["Default"](f"{prefix}.bias.weight")) + + def apply(self, x, domain_ids): + domain_ids = domain_ids.to(device=x.device, dtype=torch.long).reshape(-1) + if x.shape[0] != domain_ids.shape[0]: + raise ValueError(f"Cosmos3 action domain ids must match token count: {domain_ids.shape[0]} vs {x.shape[0]}") + weight = self.fc.tensor.to(device=x.device, dtype=x.dtype)[domain_ids].view(-1, self.input_size, self.output_size) + bias = self.bias.tensor.to(device=x.device, dtype=x.dtype)[domain_ids].view(-1, self.output_size) + return torch.bmm(x.unsqueeze(1), weight).squeeze(1) + bias class Cosmos3PreWeights(WeightModule): @@ -16,3 +35,16 @@ def __init__(self, config): "time_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("time_embedder.linear_2.weight", "time_embedder.linear_2.bias"), ) + if config.get("sound_gen", False): + self.add_module("audio_proj_in", MM_WEIGHT_REGISTER["Default"]("audio_proj_in.weight", "audio_proj_in.bias")) + self.register_parameter("audio_modality_embed", TENSOR_REGISTER["Default"]("audio_modality_embed")) + if config.get("action_gen", False): + self.add_module( + "action_proj_in", + Cosmos3DomainAwareLinearWeights( + "action_proj_in", + input_size=config.get("action_dim", config.get("max_action_dim", 64)), + output_size=config["hidden_size"], + ), + ) + self.register_parameter("action_modality_embed", TENSOR_REGISTER["Default"]("action_modality_embed")) diff --git a/lightx2v/models/runners/cosmos3/cosmos3_runner.py b/lightx2v/models/runners/cosmos3/cosmos3_runner.py index 24f26464d..c6670abe7 100644 --- a/lightx2v/models/runners/cosmos3/cosmos3_runner.py +++ b/lightx2v/models/runners/cosmos3/cosmos3_runner.py @@ -1,6 +1,12 @@ import gc +import json import os +import subprocess +import tempfile +import wave +import imageio +import imageio_ffmpeg as ffmpeg import numpy as np import torch import torch.distributed as dist @@ -8,6 +14,7 @@ from loguru import logger from transformers import AutoTokenizer +from lightx2v.models.audio_encoders.hf.cosmos3.sound_tokenizer import Cosmos3SoundTokenizer from lightx2v.models.networks.cosmos3.model import Cosmos3TransformerModel from lightx2v.models.runners.default_runner import DefaultRunner from lightx2v.models.schedulers.cosmos3.scheduler import Cosmos3Scheduler @@ -23,10 +30,55 @@ _SYSTEM_PROMPT_IMAGE = "You are a helpful assistant who will generate images from a give prompt." _SYSTEM_PROMPT_VIDEO = "You are a helpful assistant who will generate videos from a give prompt." +_ACTION_VIEWPOINT_TEMPLATES = { + "ego_view": "This video is captured from a first-person perspective looking at the scene.", + "third_person_view": "This video is captured from a third-person perspective looking towards the agent from the front.", + "wrist_view": "This video is captured from a wrist-mounted camera.", + "concat_view": "This video contains concatenated views from multiple camera perspectives.", +} + +_EMBODIMENT_TO_DOMAIN_ID = { + "no_action": 0, + "av": 1, + "camera_pose": 2, + "hand_pose": 3, + "pusht": 4, + "libero": 5, + "umi": 6, + "bridge_orig_lerobot": 7, + "droid_lerobot": 8, + "robomind-franka": 8, + "galbot": 9, + "robomind-franka-dual": 12, + "robomind-ur": 13, + "agibotworld": 15, + "agibot_gear_gripper": 15, + "agibot_gear_gripper_ext": 15, + "fractal": 20, +} + +_EMBODIMENT_TO_RAW_ACTION_DIM = { + "av": 9, + "camera_pose": 9, + "pusht": 2, + "umi": 10, + "bridge_orig_lerobot": 10, + "droid_lerobot": 10, + "robomind-franka": 10, + "robomind-franka-dual": 20, + "robomind-ur": 10, + "galbot": 30, + "agibotworld": 29, + "agibot_gear_gripper": 29, + "agibot_gear_gripper_ext": 29, + "fractal": 10, + "hand_pose": 57, +} + @RUNNER_REGISTER("cosmos3") class Cosmos3Runner(DefaultRunner): - model_cpu_offload_seq = "transformer->vae" + model_cpu_offload_seq = "transformer->vae->sound_tokenizer" _callback_tensor_inputs = ["latents"] @ProfilingContext4DebugL2("Load models") @@ -55,6 +107,10 @@ def load_image_encoder(self): def load_vae(self): return Cosmos3WanVAE(self.config) + def load_sound_tokenizer(self): + logger.info("Loading Cosmos3 sound tokenizer") + return Cosmos3SoundTokenizer.from_pretrained(self.config["model_path"], device=AI_DEVICE, dtype=GET_DTYPE()) + def init_modules(self): logger.info("Initializing Cosmos3 runner modules...") if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False): @@ -63,8 +119,12 @@ def init_modules(self): assert self.config.get("cpu_offload", False) if hasattr(self, "model") and self.model is not None: self.model.set_scheduler(self.scheduler) - if self.config["task"] not in ("t2i", "i2v"): - raise NotImplementedError(f"Cosmos3Runner currently supports tasks t2i/i2v, got {self.config['task']}") + if self.config["task"] not in ("t2i", "t2v", "i2v", "t2av", "i2av", "i2va", "v2av"): + raise NotImplementedError(f"Cosmos3Runner currently supports tasks t2i/t2v/i2v/t2av/i2av/i2va/v2av, got {self.config['task']}") + if (self.config.get("enable_sound", False) or self.config["task"] in ("t2av", "i2av")) and not self.config.get("sound_gen", False): + raise ValueError("Cosmos3 sound generation requires a checkpoint with sound_gen=True.") + if (self.config.get("action_mode", "") or self.config["task"] in ("i2va", "v2av")) and not self.config.get("action_gen", False): + raise ValueError("Cosmos3 action generation requires a checkpoint with action_gen=True.") self.run_input_encoder = self._run_input_encoder_local self.run_dit = self._run_dit_local self.config.lock() @@ -79,11 +139,98 @@ def _append_prompt_template(self, base: str, addition: str) -> str: base = (base or "").rstrip(".") return f"{base}. {addition}" if base else addition - def _tokenize_chat(self, text: str, is_image: bool): - conversations = [ - {"role": "system", "content": _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO}, - {"role": "user", "content": text}, - ] + def _resolve_prompt_text(self, text: str): + if not isinstance(text, str): + return text + if not os.path.isfile(text): + return text + if text.endswith(".json"): + with open(text, "r") as f: + return json.dumps(json.load(f)) + with open(text, "r") as f: + return f.read().strip() + + def _load_action_spec(self): + if hasattr(self, "_action_spec"): + return self._action_spec + action_path = getattr(self.input_info, "action_path", None) or self.config.get("action_path", None) + self._action_spec = {} + if action_path and os.path.isfile(action_path): + with open(action_path, "r") as f: + spec = json.load(f) + if isinstance(spec, dict): + self._action_spec = spec + else: + self._action_spec = {"raw_actions": spec} + return self._action_spec + + def _get_action_mode(self): + mode = getattr(self.input_info, "action_mode", None) or self.config.get("action_mode", "") + if self.config["task"] in ("i2va", "v2av") and not mode: + mode = "inverse_dynamics" if self.config["task"] == "v2av" else "forward_dynamics" + return mode or None + + def _get_action_value(self, key, default=None): + value = getattr(self.input_info, key, None) + if value not in (None, ""): + return value + spec = self._load_action_spec() + if key in spec and spec[key] not in (None, ""): + return spec[key] + return self.config.get(key, default) + + def _get_target_video_length(self): + if self._get_action_mode(): + return int(getattr(self.input_info, "target_video_length", 0) or self.config.get("target_video_length", 1)) + input_frames = int(getattr(self.input_info, "target_video_length", 0) or 0) + if input_frames and input_frames != 81: + return input_frames + return int(self.config.get("target_video_length", input_frames or 1)) + + def _prepare_action_context(self): + if hasattr(self, "_action_spec"): + del self._action_spec + if not self._get_action_mode(): + return + spec = self._load_action_spec() + if not getattr(self.input_info, "prompt", "") and spec.get("prompt"): + self.input_info.prompt = spec["prompt"] + chunk_size = int(self._get_action_value("action_chunk_size", self.config.get("action_chunk_size", 16))) + self.input_info.action_chunk_size = chunk_size + self.input_info.target_video_length = chunk_size + 1 + if spec.get("fps") and "target_fps" not in self.config: + self.input_info.target_fps = float(spec["fps"]) + + @staticmethod + def _build_action_json_prompt(description, view_point, num_frames, fps, height, width): + duration_seconds = num_frames / fps if fps > 0 else 0.0 + duration = int(duration_seconds) if duration_seconds >= 0 and np.isfinite(duration_seconds) else 0 + action_end = round(duration_seconds) if duration_seconds >= 0 and np.isfinite(duration_seconds) else 0 + minutes, seconds = divmod(action_end, 60) + desc = description.strip() + if desc and not desc.endswith((".", "!", "?")): + desc = f"{desc}." + prompt = {} + framing = _ACTION_VIEWPOINT_TEMPLATES.get(view_point) + if framing: + prompt["cinematography"] = {"framing": framing} + ratio = width / height if height > 0 else 1.0 + aspect_ratio = min( + ("1,1", "4,3", "3,4", "16,9", "9,16"), + key=lambda r: abs(int(r.split(",")[0]) / int(r.split(",")[1]) - ratio), + ) + prompt["actions"] = [{"time": f"0:00-{minutes}:{seconds:02d}", "description": desc}] + prompt["duration"] = f"{duration}s" + prompt["fps"] = float(fps) + prompt["resolution"] = {"H": int(height), "W": int(width)} + prompt["aspect_ratio"] = aspect_ratio + return json.dumps(prompt) + + def _tokenize_chat(self, text: str, is_image: bool, use_system_prompt=True): + conversations = [] + if use_system_prompt: + conversations.append({"role": "system", "content": _SYSTEM_PROMPT_IMAGE if is_image else _SYSTEM_PROMPT_VIDEO}) + conversations.append({"role": "user", "content": text}) kwargs = { "tokenize": True, "add_generation_prompt": True, @@ -111,19 +258,31 @@ def _tokenize_chat(self, text: str, is_image: bool): return list(input_ids) def tokenize_prompt(self, prompt, negative_prompt=None): + prompt = self._resolve_prompt_text(prompt) + negative_prompt = self._resolve_prompt_text(negative_prompt) if negative_prompt is not None else None height = int(self.input_info.auto_height) width = int(self.input_info.auto_width) - num_frames = int(self.config.get("target_video_length", 1)) + num_frames = self._get_target_video_length() fps = float(self.config.get("target_fps", 24.0)) is_image = num_frames == 1 negative_prompt = "" if negative_prompt is None else negative_prompt + action_mode = self._get_action_mode() cond_text = prompt uncond_text = negative_prompt - if not is_image and self.config.get("add_duration_template", True): + if action_mode: + cond_text = self._build_action_json_prompt( + prompt, + view_point=self._get_action_value("view_point", "ego_view"), + num_frames=num_frames, + fps=fps, + height=height, + width=width, + ) + elif not is_image and self.config.get("add_duration_template", True): cond_text = self._append_prompt_template(cond_text, f"The video is {num_frames / fps:.1f} seconds long and is of {fps:.0f} FPS.") uncond_text = self._append_prompt_template(uncond_text, f"The video is not {num_frames / fps:.1f} seconds long and is not of {fps:.0f} FPS.") - if self.config.get("add_resolution_template", True): + if not action_mode and self.config.get("add_resolution_template", True): if is_image: cond_text = self._append_prompt_template(cond_text, f"This image is of {height}x{width} resolution.") uncond_text = self._append_prompt_template(uncond_text, f"This image is not of {height}x{width} resolution.") @@ -136,8 +295,9 @@ def tokenize_prompt(self, prompt, negative_prompt=None): if eos_token_id is None or vision_start_id is None: raise ValueError("Cosmos3 tokenizer must provide eos_token_id and <|vision_start|>.") - cond_input_ids = self._tokenize_chat(cond_text, is_image=is_image) + [eos_token_id, vision_start_id] - uncond_input_ids = self._tokenize_chat(uncond_text, is_image=is_image) + [eos_token_id, vision_start_id] + use_system_prompt = self.config.get("use_system_prompt", not bool(action_mode)) + cond_input_ids = self._tokenize_chat(cond_text, is_image=is_image, use_system_prompt=use_system_prompt) + [eos_token_id, vision_start_id] + uncond_input_ids = self._tokenize_chat(uncond_text, is_image=is_image, use_system_prompt=use_system_prompt) + [eos_token_id, vision_start_id] return cond_input_ids, uncond_input_ids @ProfilingContext4DebugL2("Run Encoders") @@ -176,7 +336,7 @@ def set_target_shape(self): height, width = rounded_height, rounded_width latent_channels = int(self.config.get("latent_channel", 48)) - pixel_frames = int(self.config.get("target_video_length", 1)) + pixel_frames = self._get_target_video_length() latent_frames = (pixel_frames - 1) // temporal_scale + 1 self.input_info.auto_height = height self.input_info.auto_width = width @@ -201,7 +361,7 @@ def _load_i2v_condition_frame(self): @ProfilingContext4DebugL2("Prepare i2v condition") def _prepare_i2v_condition_latents(self): - if self.config["task"] != "i2v": + if self.config["task"] not in ("i2v", "i2av"): return if hasattr(self.input_info, "vision_condition_latents") and self.input_info.vision_condition_latents is not None: return @@ -209,7 +369,7 @@ def _prepare_i2v_condition_latents(self): if loaded_vae_here: self.vae = self.load_vae() frame = self._load_i2v_condition_frame() - num_frames = int(self.config.get("target_video_length", 189)) + num_frames = self._get_target_video_length() video = frame.unsqueeze(2).expand(-1, -1, num_frames, -1, -1).contiguous() condition_latents = self.vae.encode(video) self.input_info.vision_condition_latents = condition_latents @@ -220,12 +380,175 @@ def _prepare_i2v_condition_latents(self): torch_device_module.empty_cache() gc.collect() + def _frame_array_to_tensor(self, frame, height, width): + if frame.ndim == 2: + frame = np.stack([frame] * 3, axis=-1) + if frame.shape[-1] == 4: + frame = frame[..., :3] + resample = getattr(Image, "Resampling", Image).BILINEAR + image = Image.fromarray(frame.astype(np.uint8)).convert("RGB").resize((width, height), resample=resample) + frame = np.asarray(image).astype(np.float32) / 127.5 - 1.0 + return torch.from_numpy(frame).permute(2, 0, 1) + + def _load_video_tensor(self, video_path, num_frames, height, width, keep_first=True): + if not video_path or not os.path.isfile(video_path): + raise FileNotFoundError(f"Cosmos3 action video_path does not exist: {video_path}") + reader = imageio.get_reader(video_path) + frames = [] + try: + for frame in reader: + frames.append(self._frame_array_to_tensor(np.asarray(frame), height, width)) + if len(frames) >= num_frames: + break + finally: + reader.close() + if not frames: + raise ValueError(f"Cosmos3 could not read frames from video_path: {video_path}") + if keep_first: + frames = frames[:1] + while len(frames) < num_frames: + frames.append(frames[-1].clone()) + return torch.stack(frames[:num_frames], dim=1).unsqueeze(0).to(device=AI_DEVICE, dtype=GET_DTYPE()) + + def _load_image_tensor_by_path(self, image_path, height, width): + if not image_path or not os.path.isfile(image_path): + raise FileNotFoundError(f"Cosmos3 action image_path does not exist: {image_path}") + resample = getattr(Image, "Resampling", Image).BILINEAR + with Image.open(image_path) as image: + image = image.convert("RGB").resize((width, height), resample=resample) + frame = np.asarray(image).astype(np.float32) / 127.5 - 1.0 + return torch.from_numpy(frame).permute(2, 0, 1).unsqueeze(0).to(device=AI_DEVICE, dtype=GET_DTYPE()) + + def _get_action_chunk_index(self): + return int(getattr(self, "_action_chunk_index", self.config.get("action_chunk_index", 0))) + + def _get_action_num_chunks(self): + spec = self._load_action_spec() + default_num_chunks = int(spec.get("num_chunks", 1) or 1) + if "action_chunks" in spec: + default_num_chunks = len(spec["action_chunks"]) + num_chunks = int(self.config.get("action_num_chunks", default_num_chunks) or default_num_chunks) + if "action_chunks" in spec: + num_chunks = min(num_chunks, len(spec["action_chunks"])) + return max(num_chunks, 1) + + def _is_action_forward_multichunk(self): + return self._get_action_mode() == "forward_dynamics" and bool(self.config.get("action_multichunk", False)) and self._get_action_num_chunks() > 1 + + def _load_action_raw_actions(self, chunk_size, raw_action_dim): + spec = self._load_action_spec() + actions = spec.get("raw_actions") + actions_from_chunks = False + if actions is None and "action_chunks" in spec: + chunk_index = self._get_action_chunk_index() + actions = spec["action_chunks"][chunk_index] + actions_from_chunks = True + if actions is None and "raw_actions" not in spec: + action_path = getattr(self.input_info, "action_path", None) or self.config.get("action_path", None) + if action_path and os.path.isfile(action_path): + with open(action_path, "r") as f: + actions = json.load(f) + if actions is None: + raise ValueError("Cosmos3 forward_dynamics requires raw actions in --action_path or config action_path.") + actions = torch.as_tensor(actions, dtype=GET_DTYPE(), device=AI_DEVICE) + if actions.ndim == 3 and actions.shape[0] == 1: + actions = actions.squeeze(0) + if actions.ndim != 2: + raise ValueError(f"Cosmos3 raw actions must have shape [T, D], got {tuple(actions.shape)}") + if actions.shape[1] != raw_action_dim: + raise ValueError(f"Cosmos3 raw action dim mismatch for domain: expected {raw_action_dim}, got {actions.shape[1]}") + if not actions_from_chunks and actions.shape[0] > chunk_size: + start = self._get_action_chunk_index() * chunk_size + if start >= actions.shape[0]: + raise ValueError(f"Cosmos3 action_chunk_index={self._get_action_chunk_index()} is out of range for raw_actions length {actions.shape[0]}.") + actions = actions[start : start + chunk_size] + if actions.shape[0] < chunk_size: + actions = torch.cat([actions, actions[-1:].expand(chunk_size - actions.shape[0], -1)], dim=0) + return actions[:chunk_size] + + @ProfilingContext4DebugL2("Prepare action condition") + def _prepare_action_condition_latents(self): + action_mode = self._get_action_mode() + if not action_mode: + return + if hasattr(self.input_info, "action_latents") or hasattr(self.input_info, "action_latent_shape"): + return + chunk_size = int(getattr(self.input_info, "action_chunk_size", 0) or self._get_action_value("action_chunk_size", 16)) + action_dim = int(self.config.get("action_dim", self.config.get("max_action_dim", 64))) + domain_name = self._get_action_value("domain_name", None) + if not domain_name: + raise ValueError("Cosmos3 action generation requires domain_name in config or action JSON.") + if domain_name not in _EMBODIMENT_TO_DOMAIN_ID or domain_name not in _EMBODIMENT_TO_RAW_ACTION_DIM: + raise ValueError(f"Unsupported Cosmos3 action domain_name={domain_name!r}") + raw_action_dim = int(_EMBODIMENT_TO_RAW_ACTION_DIM[domain_name]) + if raw_action_dim > action_dim: + raise ValueError(f"Cosmos3 raw_action_dim={raw_action_dim} exceeds model action_dim={action_dim}") + + height = int(self.input_info.auto_height) + width = int(self.input_info.auto_width) + num_frames = chunk_size + 1 + image_path = getattr(self.input_info, "image_path", None) or self.config.get("image_path", "") + video_path = getattr(self.input_info, "video_path", None) or self.config.get("video_path", "") + + loaded_vae_here = not hasattr(self, "vae") or self.vae is None + if loaded_vae_here: + self.vae = self.load_vae() + if action_mode == "inverse_dynamics": + video = self._load_video_tensor(video_path, num_frames, height, width, keep_first=False) + elif image_path: + frame = self._load_image_tensor_by_path(image_path, height, width) + video = frame.unsqueeze(2).expand(-1, -1, num_frames, -1, -1).contiguous() + else: + video = self._load_video_tensor(video_path, num_frames, height, width, keep_first=True) + + condition_latents = self.vae.encode(video) + self.input_info.vision_condition_latents = condition_latents + if action_mode == "inverse_dynamics": + self.input_info.vision_condition_frame_indexes = list(range(condition_latents.shape[2])) + else: + self.input_info.vision_condition_frame_indexes = [0] + + self.input_info.action_domain_id = int(_EMBODIMENT_TO_DOMAIN_ID[domain_name]) + self.input_info.raw_action_dim = raw_action_dim + if action_mode == "forward_dynamics": + raw_actions = self._load_action_raw_actions(chunk_size, raw_action_dim) + if raw_action_dim < action_dim: + padding = torch.zeros((chunk_size, action_dim - raw_action_dim), device=AI_DEVICE, dtype=raw_actions.dtype) + raw_actions = torch.cat([raw_actions, padding], dim=-1) + self.input_info.action_latents = raw_actions + self.input_info.action_condition_frame_indexes = list(range(chunk_size)) + elif action_mode in ("inverse_dynamics", "policy"): + self.input_info.action_latent_shape = (chunk_size, action_dim) + self.input_info.action_condition_frame_indexes = [] + else: + raise ValueError(f"Unsupported Cosmos3 action_mode={action_mode!r}") + + del video + if loaded_vae_here and (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)): + del self.vae + torch_device_module.empty_cache() + gc.collect() + + def _clear_action_condition_state(self): + for name in ( + "vision_condition_latents", + "vision_condition_frame_indexes", + "action_latents", + "action_latent_shape", + "action_condition_frame_indexes", + "action_domain_id", + "raw_action_dim", + ): + if hasattr(self.input_info, name): + delattr(self.input_info, name) + @ProfilingContext4DebugL2("Run DiT") def _run_dit_local(self, total_steps=None): - if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + if (self.config.get("lazy_load", False) or self.config.get("unload_modules", False)) and (not hasattr(self, "model") or self.model is None): self.model = self.load_transformer() self.model.set_scheduler(self.scheduler) self._prepare_i2v_condition_latents() + self._prepare_action_condition_latents() self.model.scheduler.prepare(self.input_info) if hasattr(self.input_info, "vision_condition_latents"): self.input_info.vision_condition_latents = None @@ -246,6 +569,100 @@ def run_vae_decoder(self, latents): gc.collect() return images + @ProfilingContext4DebugL1("Run Sound Decoder") + def run_sound_decoder(self, sound_latents): + if sound_latents is None: + return None + if not hasattr(self, "sound_tokenizer") or self.sound_tokenizer is None: + self.sound_tokenizer = self.load_sound_tokenizer() + decoder_dtype = next(self.sound_tokenizer.parameters()).dtype + sound = self.sound_tokenizer.decode(sound_latents.to(device=AI_DEVICE, dtype=decoder_dtype)).detach().cpu() + if self.config.get("lazy_load", False) or self.config.get("unload_modules", False): + del self.sound_tokenizer + torch_device_module.empty_cache() + gc.collect() + return sound + + def _write_wav(self, path, audio, sample_rate): + audio = audio.detach().float().cpu().clamp(-1.0, 1.0) + if audio.ndim == 1: + audio = audio.unsqueeze(0) + pcm = (audio.transpose(0, 1).numpy() * 32767.0).round().astype(np.int16) + with wave.open(path, "wb") as f: + f.setnchannels(int(audio.shape[0])) + f.setsampwidth(2) + f.setframerate(int(sample_rate)) + f.writeframes(pcm.tobytes()) + + def _mux_generated_audio(self, video_path, audio): + if audio is None: + return + sample_rate = int(self.config.get("sound_sampling_rate", 48000)) + ffmpeg_exe = ffmpeg.get_ffmpeg_exe() + target_dir = os.path.dirname(video_path) or "." + os.makedirs(target_dir, exist_ok=True) + with tempfile.TemporaryDirectory(prefix=".cosmos3_mux_", dir=target_dir) as tmp_dir: + wav_path = os.path.join(tmp_dir, "cosmos3_sound.wav") + tmp_video_path = os.path.join(tmp_dir, "cosmos3_muxed.mp4") + self._write_wav(wav_path, audio, sample_rate) + cmd = [ + ffmpeg_exe, + "-y", + "-i", + video_path, + "-i", + wav_path, + "-map", + "0:v:0", + "-map", + "1:a:0", + "-c:v", + "copy", + "-c:a", + "aac", + "-b:a", + "192k", + "-shortest", + "-f", + "mp4", + tmp_video_path, + ] + result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + if result.returncode != 0: + stderr = result.stderr.decode(errors="ignore") if result.stderr else "Unknown error" + logger.warning(f"Cosmos3 generated audio mux failed, keep silent video. Error: {stderr}") + return + os.replace(tmp_video_path, video_path) + + def _collect_action_output(self, action_latents): + if action_latents is None: + return None + action_mode = self._get_action_mode() + if action_mode not in ("inverse_dynamics", "policy"): + return None + raw_action_dim = getattr(self.input_info, "raw_action_dim", None) + action = action_latents.detach().cpu() + if raw_action_dim is not None: + action = action[:, : int(raw_action_dim)] + return action + + def _save_action_output(self, action): + if action is None: + return + save_action_path = getattr(self.input_info, "save_action_path", None) + if not save_action_path: + save_result_path = getattr(self.input_info, "save_result_path", None) + if not save_result_path: + return + root, _ = os.path.splitext(save_result_path) + save_action_path = f"{root}_action.json" + save_dir = os.path.dirname(save_action_path) + if save_dir: + os.makedirs(save_dir, exist_ok=True) + with open(save_action_path, "w") as f: + json.dump(action.tolist(), f) + logger.info(f"Action saved: {save_action_path}") + def run(self, total_steps=None): if total_steps is None: total_steps = self.model.scheduler.infer_steps @@ -261,7 +678,32 @@ def run(self, total_steps=None): self.progress_callback(((step_index + 1) / total_steps) * 100, 100) return self.model.scheduler.latents, self.model.scheduler.generator - def _save_images(self, images, input_info, log_prefix="Image saved"): + def _images_to_video_tensor(self, images): + if isinstance(images, torch.Tensor): + if images.dim() == 5: + video = images[0].permute(1, 2, 3, 0).contiguous() + elif images.dim() == 4: + video = images + else: + raise ValueError(f"Cosmos3 video tensor output must be 4D or 5D, got {tuple(images.shape)}") + return video.detach().float().cpu().clamp(0, 1) + frames = [] + for image in images: + if isinstance(image, torch.Tensor): + frame = image.detach().float().cpu() + else: + frame = torch.from_numpy(np.asarray(image).astype(np.float32) / 255.0) + frames.append(frame) + if not frames: + raise ValueError("Cosmos3 video output is empty.") + return torch.stack(frames, dim=0).float().clamp(0, 1) + + def _save_frame_tensor(self, frame, path): + frame = frame.detach().float().cpu().clamp(0, 1) + frame = (frame.numpy() * 255.0).round().astype(np.uint8) + Image.fromarray(frame).save(path) + + def _save_images(self, images, input_info, log_prefix="Image saved", sound=None): if dist.is_initialized() and dist.get_rank() != 0: return if input_info.return_result_tensor or not input_info.save_result_path: @@ -270,23 +712,15 @@ def _save_images(self, images, input_info, log_prefix="Image saved"): save_dir = os.path.dirname(save_path) if save_dir: os.makedirs(save_dir, exist_ok=True) - if self.config["task"] == "i2v": - if isinstance(images, torch.Tensor): - if images.dim() == 5: - video = images[0].permute(1, 2, 3, 0).contiguous() - elif images.dim() == 4: - video = images - else: - raise ValueError(f"Cosmos3 i2v tensor output must be 4D or 5D, got {tuple(images.shape)}") - else: - frames = [torch.from_numpy(np.asarray(image).astype(np.float32) / 255.0) for image in images] - video = torch.stack(frames, dim=0) + if self._is_video_output(): + video = self._images_to_video_tensor(images) save_to_video( video.clamp(0, 1), save_path, fps=float(self.config.get("target_fps", 24.0)), method=self.config.get("save_video_method", "ffmpeg"), ) + self._mux_generated_audio(save_path, sound) logger.info(f"Video saved: {save_path}") return image_prefix, image_suffix = os.path.splitext(save_path) @@ -301,18 +735,35 @@ def _save_images(self, images, input_info, log_prefix="Image saved"): image.save(f"{image_prefix}.{image_suffix}") logger.info(f"{log_prefix}: {image_prefix}.{image_suffix}") - def _finalize_pipeline_outputs(self, input_info, images, latents=None, generator=None): + def _finalize_pipeline_outputs(self, input_info, images, latents=None, generator=None, sound=None, action=None): if latents is not None: del latents if generator is not None: del generator torch_device_module.empty_cache() gc.collect() + output_key = "video" if self._is_video_output() else "images" if input_info.return_result_tensor: - return {"video" if self.config["task"] == "i2v" else "images": images} + outputs = {output_key: images} + if sound is not None: + outputs["sound"] = sound + if action is not None: + outputs["action"] = action + return outputs if input_info.save_result_path is not None: - return {"video" if self.config["task"] == "i2v" else "images": None} - return {"video" if self.config["task"] == "i2v" else "images": images} + outputs = {output_key: None} + if action is not None: + outputs["action"] = None + return outputs + outputs = {output_key: images} + if sound is not None: + outputs["sound"] = sound + if action is not None: + outputs["action"] = action + return outputs + + def _is_video_output(self): + return int(self.config.get("target_video_length", 1)) > 1 def end_run(self): if hasattr(self, "model") and self.model is not None: @@ -330,14 +781,70 @@ def end_run(self): torch_device_module.empty_cache() gc.collect() + @ProfilingContext4DebugL1("RUN action multichunk pipeline") + def _run_action_forward_multichunk_pipeline(self, input_info): + num_chunks = self._get_action_num_chunks() + original_image_path = getattr(input_info, "image_path", "") + save_result_path = getattr(input_info, "save_result_path", "") or "" + temp_root = os.path.dirname(save_result_path) or "." + os.makedirs(temp_root, exist_ok=True) + + stitched_videos = [] + generator = None + current_image_path = original_image_path + try: + with tempfile.TemporaryDirectory(prefix=".cosmos3_rollout_", dir=temp_root) as tmp_dir: + for chunk_index in range(num_chunks): + logger.info(f"Cosmos3 action forward rollout chunk: {chunk_index + 1} / {num_chunks}") + self._action_chunk_index = chunk_index + input_info.image_path = current_image_path + self._clear_action_condition_state() + + latents, generator = self.run_dit() + images = self.run_vae_decoder(latents) + video = self._images_to_video_tensor(images) + if video.shape[0] <= 1: + raise ValueError(f"Cosmos3 action chunk output must contain condition + generated frames, got {video.shape[0]} frame.") + + chunk_video = video if chunk_index == 0 else video[1:] + stitched_videos.append(chunk_video.cpu()) + if chunk_index + 1 < num_chunks: + next_frame_path = os.path.join(tmp_dir, f"chunk_{chunk_index + 1:05d}_last.png") + self._save_frame_tensor(video[-1], next_frame_path) + current_image_path = next_frame_path + + if hasattr(self, "model") and self.model is not None: + self.model.scheduler.clear() + del latents, images, video + torch_device_module.empty_cache() + gc.collect() + finally: + if hasattr(self, "_action_chunk_index"): + del self._action_chunk_index + input_info.image_path = original_image_path + self._clear_action_condition_state() + + video = torch.cat(stitched_videos, dim=0) + self._save_images(video, input_info, log_prefix="Video saved", sound=None) + self.end_run() + return self._finalize_pipeline_outputs(input_info, video, generator=generator) + @ProfilingContext4DebugL1("RUN pipeline") def run_pipeline(self, input_info): self.input_info = input_info + self._prepare_action_context() self.set_target_shape() self.inputs = self.run_input_encoder() logger.info(f"input_info: {self.input_info}") + if self._is_action_forward_multichunk(): + return self._run_action_forward_multichunk_pipeline(input_info) latents, generator = self.run_dit() + sound_latents = getattr(self.model.scheduler, "sound_latents", None) if hasattr(self, "model") else None + action_latents = getattr(self.model.scheduler, "action_latents", None) if hasattr(self, "model") else None + sound = self.run_sound_decoder(sound_latents) + action = self._collect_action_output(action_latents) images = self.run_vae_decoder(latents) + self._save_images(images, input_info, log_prefix="Image saved", sound=sound) + self._save_action_output(action) self.end_run() - self._save_images(images, input_info, log_prefix="Image saved") - return self._finalize_pipeline_outputs(input_info, images, latents=latents, generator=generator) + return self._finalize_pipeline_outputs(input_info, images, latents=latents, generator=generator, sound=sound, action=action) diff --git a/lightx2v/models/schedulers/cosmos3/scheduler.py b/lightx2v/models/schedulers/cosmos3/scheduler.py index 54c6ec60c..edafdd9e2 100644 --- a/lightx2v/models/schedulers/cosmos3/scheduler.py +++ b/lightx2v/models/schedulers/cosmos3/scheduler.py @@ -1,3 +1,4 @@ +import copy import json import math import os @@ -497,8 +498,12 @@ def __init__(self, config): if sample_shift is not None: self.scheduler_config["flow_shift"] = float(sample_shift) self.unipc = None + self.sound_unipc = None + self.action_unipc = None self.timesteps = None self.noise_pred = None + self.noise_pred_sound = None + self.noise_pred_action = None self.keep_latents_dtype_in_scheduler = True def _build_unipc(self): @@ -542,13 +547,47 @@ def prepare_latents(self, input_info): self.vision_condition_mask = mask else: self.latents = noise + self.sound_latents = None + if self.config.get("enable_sound", False) or self.config.get("task") in ("t2av", "i2av"): + sound_shape = getattr(input_info, "sound_latent_shape", None) or getattr(input_info, "audio_latent_shape", None) + if not sound_shape: + sound_dim = int(self.config.get("sound_dim", 64)) + sound_len = int(self.config.get("sound_latent_length", 0)) + if sound_len <= 0: + num_frames = int(self.config.get("target_video_length", 189)) + fps = float(self.config.get("target_fps", 24.0)) + sampling_rate = int(self.config.get("sound_sampling_rate", 48000)) + hop_size = int(self.config.get("sound_hop_size", 1920)) + sound_len = (int(num_frames / fps * sampling_rate) + hop_size - 1) // hop_size + sound_shape = (sound_dim, sound_len) + self.sound_latents = torch.randn(tuple(sound_shape), generator=self.generator, device=AI_DEVICE, dtype=GET_DTYPE()) + + self.action_latents = None + self.action_domain_id = getattr(input_info, "action_domain_id", None) + self.action_condition_frame_indexes = getattr(input_info, "action_condition_frame_indexes", None) + self.raw_action_dim = getattr(input_info, "raw_action_dim", None) + action_latents = getattr(input_info, "action_latents", None) + if action_latents is not None: + self.action_latents = action_latents.to(device=AI_DEVICE, dtype=GET_DTYPE()) + if self.raw_action_dim is not None: + self.action_latents[:, int(self.raw_action_dim) :] = 0 + else: + action_shape = getattr(input_info, "action_latent_shape", None) + if action_shape is not None: + self.action_latents = torch.randn(tuple(action_shape), generator=self.generator, device=AI_DEVICE, dtype=GET_DTYPE()) + if self.raw_action_dim is not None: + self.action_latents[:, int(self.raw_action_dim) :] = 0 self.noise_pred = None + self.noise_pred_sound = None + self.noise_pred_action = None def prepare(self, input_info): self.prepare_latents(input_info) self.unipc = self._build_unipc() self.unipc.set_timesteps(int(self.config["infer_steps"]), device=AI_DEVICE) self.timesteps = self.unipc.timesteps + self.sound_unipc = copy.deepcopy(self.unipc) if self.sound_latents is not None else None + self.action_unipc = copy.deepcopy(self.unipc) if self.action_latents is not None else None self.infer_steps = len(self.timesteps) self.step_index = 0 @@ -568,13 +607,42 @@ def step_post(self): self.latents.unsqueeze(0), return_dict=False, )[0].squeeze(0) + if self.sound_latents is not None: + if self.noise_pred_sound is None: + raise RuntimeError("Cosmos3Scheduler requires noise_pred_sound before sound step_post().") + self.sound_latents = self.sound_unipc.step( + self.noise_pred_sound.unsqueeze(0), + t, + self.sound_latents.unsqueeze(0), + return_dict=False, + )[0].squeeze(0) + if self.action_latents is not None and self.noise_pred_action is not None: + self.action_latents = self.action_unipc.step( + self.noise_pred_action.unsqueeze(0), + t, + self.action_latents.unsqueeze(0), + return_dict=False, + )[0].squeeze(0) + if self.raw_action_dim is not None: + self.action_latents[:, int(self.raw_action_dim) :] = 0 self.noise_pred = None + self.noise_pred_sound = None + self.noise_pred_action = None def clear(self): self.generator = None self.latents = None + self.sound_latents = None + self.action_latents = None self.timesteps = None self.unipc = None + self.sound_unipc = None + self.action_unipc = None self.noise_pred = None + self.noise_pred_sound = None + self.noise_pred_action = None self.vision_condition_frame_indexes = None self.vision_condition_mask = None + self.action_domain_id = None + self.action_condition_frame_indexes = None + self.raw_action_dim = None diff --git a/lightx2v/models/video_encoders/hf/cosmos3/vae.py b/lightx2v/models/video_encoders/hf/cosmos3/vae.py index a9f931ae2..80d8e308d 100644 --- a/lightx2v/models/video_encoders/hf/cosmos3/vae.py +++ b/lightx2v/models/video_encoders/hf/cosmos3/vae.py @@ -584,7 +584,8 @@ def load(self): vae_path = self.config.get("vae_path", os.path.join(self.config["model_path"], "vae")) with open(os.path.join(vae_path, "config.json"), "r") as f: self.vae_config = json.load(f) - self.load_encoder = self.config.get("task") == "i2v" or self.config.get("cosmos3_load_vae_encoder", False) + encoder_tasks = {"i2v", "i2av", "i2va", "v2av"} + self.load_encoder = self.config.get("task") in encoder_tasks or self.config.get("cosmos3_load_vae_encoder", False) model_cls = AutoencoderKLWan if self.load_encoder else AutoencoderKLWanDecodeOnly self.model = model_cls(self.vae_config).to(self.device).to(self.dtype) weight_path = os.path.join(vae_path, "diffusion_pytorch_model.safetensors") @@ -615,7 +616,7 @@ def _to_pil_frames(video: torch.Tensor): @torch.no_grad() def encode(self, video: torch.Tensor): if not self.load_encoder: - raise RuntimeError("Cosmos3WanVAE was loaded without encoder. Set task=i2v or cosmos3_load_vae_encoder=True.") + raise RuntimeError("Cosmos3WanVAE was loaded without encoder. Use a condition task or set cosmos3_load_vae_encoder=True.") if self.cpu_offload: self.model.to(torch.device(AI_DEVICE)) video = video.to(device=next(self.model.parameters()).device, dtype=self.dtype) diff --git a/lightx2v/pipeline.py b/lightx2v/pipeline.py index 73ad0c6fd..54bc7d75e 100755 --- a/lightx2v/pipeline.py +++ b/lightx2v/pipeline.py @@ -11,6 +11,7 @@ from loguru import logger from lightx2v.common.ops import * +from lightx2v.models.runners.cosmos3.cosmos3_runner import Cosmos3Runner # noqa: F401 from lightx2v.models.runners.ernie_image.ernie_image_runner import ErnieImageRunner # noqa: F401 from lightx2v.models.runners.flux2.flux2_runner import Flux2DevRunner, Flux2KleinRunner # noqa: F401 from lightx2v.models.runners.hunyuan_video.hunyuan_video_15_runner import HunyuanVideo15Runner # noqa: F401 @@ -120,6 +121,9 @@ def __init__( elif self.model_cls in ["ltx2"]: self.num_channels_latents = 128 self.audio_mel_bins = 16 + elif self.model_cls in ["cosmos3"]: + self.vae_stride = (4, 16, 16) + self.num_channels_latents = 48 if model_cls in ["qwen-image", "qwen-image-2512", "qwen-image-edit", "qwen-image-edit-2509", "qwen-image-edit-2511"]: self.CONDITION_IMAGE_SIZE = 147456 diff --git a/lightx2v/utils/input_info.py b/lightx2v/utils/input_info.py index c853da584..c865dbf98 100755 --- a/lightx2v/utils/input_info.py +++ b/lightx2v/utils/input_info.py @@ -272,7 +272,12 @@ class I2VAInputInfo: prompt_enhanced: str = field(default_factory=str) negative_prompt: str = field(default_factory=str) image_path: str = field(default_factory=str) + video_path: str = field(default_factory=str) + action_path: str = field(default_factory=str) state_path: str = field(default_factory=str) + action_mode: str = field(default_factory=str) + domain_name: str = field(default_factory=str) + view_point: str = field(default_factory=str) save_result_path: str = field(default_factory=str) save_action_path: str = field(default_factory=str) return_result_tensor: bool = field(default_factory=lambda: False) @@ -307,6 +312,10 @@ class V2AVInputInfo: # Pre-processed reference / control video (pose / canny / depth / track for # motion transfer, or the degraded source video for ICEdit). video_path: str = field(default_factory=str) + action_path: str = field(default_factory=str) + action_mode: str = field(default_factory=str) + domain_name: str = field(default_factory=str) + view_point: str = field(default_factory=str) reference_video_strength: float = field(default_factory=lambda: 1.0) reference_video_frame_cap: Optional[int] = None # Optional: mux audio from this file after save (e.g. original driving video). @@ -314,6 +323,7 @@ class V2AVInputInfo: # v2av mux path is not used because LTX2Runner overrides ``process_images_after_vae_decoder``. mux_audio_video_path: str = field(default_factory=str) save_result_path: str = field(default_factory=str) + save_action_path: str = field(default_factory=str) return_result_tensor: bool = field(default_factory=lambda: False) # shape related resize_mode: str = field(default_factory=str) diff --git a/lightx2v/utils/set_config.py b/lightx2v/utils/set_config.py index fd4795e18..093654caa 100755 --- a/lightx2v/utils/set_config.py +++ b/lightx2v/utils/set_config.py @@ -215,7 +215,7 @@ def auto_calc_config(config): if "infer_steps" not in config and "num_inference_steps" in config: config["infer_steps"] = config["num_inference_steps"] - if config["task"] in ["i2v", "s2v", "rs2v", "ltx2_s2v", "v2av"]: + if config["task"] in ["i2v", "t2av", "i2av", "i2va", "s2v", "rs2v", "ltx2_s2v", "v2av"]: if config["target_video_length"] % config["vae_stride"][0] != 1: logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.") config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1 @@ -237,6 +237,11 @@ def auto_calc_config(config): config["vae_scale_factor_spatial"] = int(vae_config.get("scale_factor_spatial", 16)) config["vae_scale_factor_temporal"] = int(vae_config.get("scale_factor_temporal", 4)) config["vae_scale_factor"] = config["vae_scale_factor_spatial"] + if config["model_cls"] == "cosmos3" and os.path.exists(os.path.join(config["model_path"], "sound_tokenizer", "config.json")): + with open(os.path.join(config["model_path"], "sound_tokenizer", "config.json"), "r") as f: + sound_config = json.load(f) + config["sound_sampling_rate"] = int(sound_config.get("sampling_rate", 48000)) + config["sound_hop_size"] = int(sound_config.get("hop_size", 1920)) return config diff --git a/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld.sh b/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld.sh new file mode 100644 index 000000000..bbee564cc --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +image_path=${model_path}/assets/example_action_fd_agibotworld_first_frame.png +action_path=${model_path}/assets/example_action_fd_agibotworld_action_chunks.json + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task i2va \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json \ +--prompt "" \ +--image_path ${image_path} \ +--action_path ${action_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_action_fd_agibotworld.mp4 \ +--seed 0 diff --git a/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.sh b/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.sh new file mode 100644 index 000000000..2e3929426 --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.sh @@ -0,0 +1,23 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +image_path=${model_path}/assets/example_action_fd_agibotworld_first_frame.png +action_path=${model_path}/assets/example_action_fd_agibotworld_action_chunks.json + +export CUDA_VISIBLE_DEVICES=3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task i2va \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld_multichunk.json \ +--prompt "" \ +--image_path ${image_path} \ +--action_path ${action_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_action_fd_agibotworld_multichunk.mp4 \ +--seed 0 diff --git a/scripts/cosmos3/cosmos3_super_omni_action_id_av.sh b/scripts/cosmos3/cosmos3_super_omni_action_id_av.sh new file mode 100644 index 000000000..8ec4fd0dd --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_action_id_av.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +video_path=${model_path}/assets/example_action_id_av_0_input.mp4 + +export CUDA_VISIBLE_DEVICES=1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task v2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_action_id_av.json \ +--prompt "You are an autonomous vehicle planning system." \ +--video_path ${video_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_action_id_av.mp4 \ +--save_action_path ${lightx2v_path}/save_results/cosmos3_super_action_id_av.json \ +--seed 0 diff --git a/scripts/cosmos3/cosmos3_super_omni_i2av.sh b/scripts/cosmos3/cosmos3_super_omni_i2av.sh new file mode 100644 index 000000000..f957c6cf8 --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_i2av.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +prompt_path=${model_path}/assets/example_i2v_prompt.json +negative_prompt_path=${model_path}/assets/negative_prompt.json +image_path=${model_path}/assets/example_i2v_input.jpg + +export CUDA_VISIBLE_DEVICES=2 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task i2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_i2av.json \ +--prompt ${prompt_path} \ +--negative_prompt ${negative_prompt_path} \ +--image_path ${image_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_omni_i2av.mp4 \ +--seed 17 diff --git a/scripts/cosmos3/cosmos3_super_omni_i2v.sh b/scripts/cosmos3/cosmos3_super_omni_i2v.sh new file mode 100644 index 000000000..c42278f1c --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_i2v.sh @@ -0,0 +1,24 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +prompt_path=${model_path}/assets/example_i2v_prompt.json +negative_prompt_path=${model_path}/assets/negative_prompt.json +image_path=${model_path}/assets/example_i2v_input.jpg + +export CUDA_VISIBLE_DEVICES=1 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task i2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_i2v.json \ +--prompt ${prompt_path} \ +--negative_prompt ${negative_prompt_path} \ +--image_path ${image_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_omni_i2v.mp4 \ +--seed 17 diff --git a/scripts/cosmos3/cosmos3_super_omni_t2av.sh b/scripts/cosmos3/cosmos3_super_omni_t2av.sh new file mode 100644 index 000000000..cf131d226 --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_t2av.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +prompt_path=${model_path}/assets/example_t2vs_prompt.json +negative_prompt_path=${model_path}/assets/negative_prompt.json + +export CUDA_VISIBLE_DEVICES=3 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task t2av \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_t2av.json \ +--prompt ${prompt_path} \ +--negative_prompt ${negative_prompt_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_omni_t2av.mp4 \ +--seed 17 diff --git a/scripts/cosmos3/cosmos3_super_omni_t2v.sh b/scripts/cosmos3/cosmos3_super_omni_t2v.sh new file mode 100644 index 000000000..4272e64a9 --- /dev/null +++ b/scripts/cosmos3/cosmos3_super_omni_t2v.sh @@ -0,0 +1,22 @@ +#!/bin/bash + +# set path firstly +lightx2v_path=/data/nvme5/gushiqiao/codes/LightX2V +model_path=/data/nvme5/gushiqiao/models/Cosmos3-Super +prompt_path=${model_path}/assets/example_t2v_prompt.json +negative_prompt_path=${model_path}/assets/negative_prompt.json + +export CUDA_VISIBLE_DEVICES=0 + +# set environment variables +source ${lightx2v_path}/scripts/base/base.sh + +python -m lightx2v.infer \ +--model_cls cosmos3 \ +--task t2v \ +--model_path $model_path \ +--config_json ${lightx2v_path}/configs/cosmos3/cosmos3_super_omni_t2v.json \ +--prompt ${prompt_path} \ +--negative_prompt ${negative_prompt_path} \ +--save_result_path ${lightx2v_path}/save_results/cosmos3_super_omni_t2v.mp4 \ +--seed 123