Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_action_fd_agibotworld.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
{
"infer_steps": 30,
"sample_guide_scale": 1.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 640,
"target_video_length": 17,
"target_fps": 10.0,
"enable_cfg": true,
"action_mode": "forward_dynamics",
"domain_name": "agibotworld",
"view_point": "concat_view",
"action_chunk_size": 16,
"action_chunk_index": 0,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"use_system_prompt": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
{
"infer_steps": 30,
"sample_guide_scale": 1.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 640,
"target_video_length": 17,
"target_fps": 10.0,
"enable_cfg": true,
"action_mode": "forward_dynamics",
"domain_name": "agibotworld",
"view_point": "concat_view",
"action_chunk_size": 16,
"action_chunk_index": 0,
"action_multichunk": true,
"action_num_chunks": 4,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"use_system_prompt": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
27 changes: 27 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_action_id_av.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
{
"infer_steps": 30,
"sample_guide_scale": 1.0,
"sample_shift": 10.0,
"target_height": 480,
"target_width": 832,
"target_video_length": 61,
"target_fps": 10.0,
"enable_cfg": true,
"action_mode": "inverse_dynamics",
"domain_name": "av",
"view_point": "ego_view",
"action_chunk_size": 60,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"use_system_prompt": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
23 changes: 23 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_i2av.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"infer_steps": 35,
"sample_guide_scale": 6.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 1280,
"target_video_length": 189,
"target_fps": 24.0,
"enable_cfg": true,
"enable_sound": true,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
22 changes: 22 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_i2v.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"infer_steps": 35,
"sample_guide_scale": 6.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 1280,
"target_video_length": 189,
"target_fps": 24.0,
"enable_cfg": true,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
23 changes: 23 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_t2av.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
{
"infer_steps": 35,
"sample_guide_scale": 6.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 1280,
"target_video_length": 189,
"target_fps": 24.0,
"enable_cfg": true,
"enable_sound": true,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
22 changes: 22 additions & 0 deletions configs/cosmos3/cosmos3_super_omni_t2v.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"infer_steps": 35,
"sample_guide_scale": 6.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 1280,
"target_video_length": 189,
"target_fps": 24.0,
"enable_cfg": true,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
22 changes: 22 additions & 0 deletions configs/cosmos3/cosmos3_super_t2v.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
{
"infer_steps": 35,
"sample_guide_scale": 6.0,
"sample_shift": 10.0,
"target_height": 720,
"target_width": 1280,
"target_video_length": 189,
"target_fps": 24.0,
"enable_cfg": true,
"feature_caching": "NoCaching",
"rms_norm_type": "one-pass",
"attn_rms_norm_type": "one-pass",
"rope_type": "triton",
"self_attn_type": "flash_attn3",
"causal_self_attn_type": "flash_attn3",
"add_resolution_template": false,
"add_duration_template": false,
"cosmos3_meta_init": true,
"vae_cpu_offload": false,
"cpu_offload": false,
"offload_granularity": "block"
}
5 changes: 5 additions & 0 deletions lightx2v/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,11 @@ def main():
default=None,
help="Directory path for lingbot camera/action control files (poses.npy, intrinsics.npy, optional action.npy).",
)
parser.add_argument("--action_mode", type=str, default=None, choices=["forward_dynamics", "inverse_dynamics", "policy"], help="Cosmos3 action mode.")
parser.add_argument("--domain_name", type=str, default=None, help="Cosmos3 action embodiment domain name.")
parser.add_argument("--view_point", type=str, default=None, help="Cosmos3 action viewpoint label.")
parser.add_argument("--action_chunk_size", type=int, default=None, help="Cosmos3 action chunk size.")
parser.add_argument("--action_chunk_index", type=int, default=None, help="Cosmos3 action chunk index when action_path contains action_chunks.")
parser.add_argument(
"--action_ckpt",
type=str,
Expand Down
Empty file.
Empty file.
Empty file.
138 changes: 138 additions & 0 deletions lightx2v/models/audio_encoders/hf/cosmos3/sound_tokenizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import json
import math
import os

import torch
import torch.nn as nn
from safetensors import safe_open
from torch.nn.utils import weight_norm


class Snake1d(nn.Module):
def __init__(self, hidden_dim, logscale=True):
super().__init__()
self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1))
self.logscale = logscale

def forward(self, hidden_states):
shape = hidden_states.shape
alpha = self.alpha if not self.logscale else torch.exp(self.alpha)
beta = self.beta if not self.logscale else torch.exp(self.beta)
hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2)
return hidden_states.reshape(shape)


class Cosmos3AudioResidualUnit(nn.Module):
def __init__(self, dimension, dilation=1):
super().__init__()
pad = ((7 - 1) * dilation) // 2
self.snake1 = Snake1d(dimension)
self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad))
self.snake2 = Snake1d(dimension)
self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1))

def forward(self, hidden_state):
output_tensor = self.conv1(self.snake1(hidden_state))
output_tensor = self.conv2(self.snake2(output_tensor))
padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
if padding > 0:
hidden_state = hidden_state[..., padding:-padding]
return hidden_state + output_tensor


class Cosmos3AudioDecoderBlock(nn.Module):
def __init__(self, input_dim, output_dim, stride=1, output_padding=0):
super().__init__()
self.snake1 = Snake1d(input_dim)
self.conv_t1 = weight_norm(
nn.ConvTranspose1d(
input_dim,
output_dim,
kernel_size=2 * stride,
stride=stride,
padding=math.ceil(stride / 2),
output_padding=output_padding,
)
)
self.res_unit1 = Cosmos3AudioResidualUnit(output_dim, dilation=1)
self.res_unit2 = Cosmos3AudioResidualUnit(output_dim, dilation=3)
self.res_unit3 = Cosmos3AudioResidualUnit(output_dim, dilation=9)

def forward(self, hidden_state):
hidden_state = self.snake1(hidden_state)
hidden_state = self.conv_t1(hidden_state)
hidden_state = self.res_unit1(hidden_state)
hidden_state = self.res_unit2(hidden_state)
hidden_state = self.res_unit3(hidden_state)
return hidden_state


class Cosmos3AudioDecoder(nn.Module):
def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples):
super().__init__()
channel_multiples = [1] + list(channel_multiples)
self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3))
block = []
for stride_index, stride in enumerate(upsampling_ratios):
block.append(
Cosmos3AudioDecoderBlock(
input_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index],
output_dim=channels * channel_multiples[len(upsampling_ratios) - stride_index - 1],
stride=stride,
output_padding=stride % 2,
)
)
self.block = nn.ModuleList(block)
self.snake1 = Snake1d(channels)
self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False))

def forward(self, hidden_state):
hidden_state = self.conv1(hidden_state)
for layer in self.block:
hidden_state = layer(hidden_state)
hidden_state = self.snake1(hidden_state)
return self.conv2(hidden_state)


class Cosmos3SoundTokenizer(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.sampling_rate = int(config.get("sampling_rate", 48000))
self.hop_size = int(config.get("hop_size") or math.prod(config.get("dec_strides", [2, 4, 5, 6, 8])))
self.decoder = Cosmos3AudioDecoder(
channels=int(config.get("dec_dim", 320)),
input_channels=int(config.get("vocoder_input_dim", 64)),
audio_channels=int(config.get("dec_out_channels", 2)),
upsampling_ratios=list(reversed(config.get("dec_strides", [2, 4, 5, 6, 8]))),
channel_multiples=list(config.get("dec_c_mults", [1, 2, 4, 8, 16])),
)

@classmethod
def from_pretrained(cls, model_path, device, dtype):
tokenizer_path = os.path.join(model_path, "sound_tokenizer")
with open(os.path.join(tokenizer_path, "config.json"), "r") as f:
config = json.load(f)
model = cls(config)
weight_path = os.path.join(tokenizer_path, "diffusion_pytorch_model.safetensors")
state_dict = {}
with safe_open(weight_path, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith("decoder."):
state_dict[key] = f.get_tensor(key)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
unexpected = [key for key in unexpected if key.startswith("decoder.")]
missing = [key for key in missing if key.startswith("decoder.")]
if missing or unexpected:
raise RuntimeError(f"Failed to load Cosmos3 sound decoder, missing={missing}, unexpected={unexpected}")
return model.to(device=device, dtype=dtype).eval()

@torch.no_grad()
def decode(self, latents):
squeeze = latents.ndim == 2
if squeeze:
latents = latents.unsqueeze(0)
audio = self.decoder(latents).clamp(-1.0, 1.0)
return audio.squeeze(0) if squeeze else audio
15 changes: 15 additions & 0 deletions lightx2v/models/networks/cosmos3/infer/module_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@ class Cosmos3PreInferModuleOutput:
vision_token_shapes: list
vision_noisy_frame_indexes: list
original_latent_shapes: list
sound_mse_loss_indexes: torch.Tensor | None = None
sound_token_shapes: list | None = None
sound_noisy_frame_indexes: list | None = None
action_mse_loss_indexes: torch.Tensor | None = None
action_token_shapes: list | None = None
action_noisy_frame_indexes: list | None = None
action_domain_ids: torch.Tensor | None = None
raw_action_dim: int | None = None
seq_p_gen_len: int | None = None
seq_p_gen_padding_size: int = 0
seq_p_local_gen_len: int | None = None
Expand All @@ -21,3 +29,10 @@ class Cosmos3PreInferModuleOutput:
class Cosmos3TransformerInferModuleOutput:
und_seq: torch.Tensor
gen_seq: torch.Tensor


@dataclass
class Cosmos3PostInferModuleOutput:
vision: torch.Tensor
sound: torch.Tensor | None = None
action: torch.Tensor | None = None
Loading
Loading