Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 90 additions & 3 deletions fastvideo/configs/models/dits/matrixgame.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass
class MatrixGameWanVideoArchConfig(WanVideoArchConfig):
class MatrixGame2WanVideoArchConfig(WanVideoArchConfig):
# Override param_names_mapping to remove patch_embedding transformation
# because MatrixGame checkpoints already have patch_embedding.proj format
param_names_mapping: dict = field(
Expand Down Expand Up @@ -67,7 +67,94 @@ def _is_transformer_block(param_name: str, module: torch.nn.Module) -> bool:


@dataclass
class MatrixGameWanVideoConfig(WanVideoConfig):
arch_config: MatrixGameWanVideoArchConfig = field(default_factory=MatrixGameWanVideoArchConfig)
class MatrixGame2WanVideoConfig(WanVideoConfig):
arch_config: MatrixGame2WanVideoArchConfig = field(default_factory=MatrixGame2WanVideoArchConfig)
prefix: str = "Wan"
_compile_conditions: list = field(default_factory=lambda: [_is_transformer_block])


@dataclass
class MatrixGame3WanVideoArchConfig(WanVideoArchConfig):
param_names_mapping: dict = field(
default_factory=lambda: {
r"^patch_embedding\.(weight|bias)$": r"patch_embedding.proj.\1",
r"^patch_embedding_wancamctrl\.(.*)$": r"camera_patch_embedding.proj.\1",
r"^time_embedding\.0\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1",
r"^time_embedding\.2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1",
r"^time_projection\.1\.(.*)$": r"condition_embedder.time_modulation.linear.\1",
r"^head\.head\.(.*)$": r"proj_out.\1",
r"^head\.modulation$": r"scale_shift_table",
r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.to_q.\2",
r"^blocks\.(\d+)\.self_attn\.k\.(.*)$": r"blocks.\1.to_k.\2",
r"^blocks\.(\d+)\.self_attn\.v\.(.*)$": r"blocks.\1.to_v.\2",
r"^blocks\.(\d+)\.self_attn\.o\.(.*)$": r"blocks.\1.to_out.\2",
r"^blocks\.(\d+)\.self_attn\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2",
r"^blocks\.(\d+)\.self_attn\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2",
r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$": r"blocks.\1.attn2.to_q.\2",
r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$": r"blocks.\1.attn2.to_k.\2",
r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$": r"blocks.\1.attn2.to_v.\2",
r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$": r"blocks.\1.attn2.to_out.\2",
r"^blocks\.(\d+)\.cross_attn\.norm_q\.(.*)$": r"blocks.\1.attn2.norm_q.\2",
r"^blocks\.(\d+)\.cross_attn\.norm_k\.(.*)$": r"blocks.\1.attn2.norm_k.\2",
r"^blocks\.(\d+)\.ffn\.0\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
r"^blocks\.(\d+)\.norm3\.(.*)$": r"blocks.\1.self_attn_residual_norm.norm.\2",
r"^blocks\.(\d+)\.modulation$": r"blocks.\1.scale_shift_table",
r"^patch_embedding\.(?!proj\.)(.*)$": r"patch_embedding.proj.\1",
r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1",
r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": r"condition_embedder.text_embedder.fc_out.\1",
r"^condition_embedder\.time_embedder\.linear_1\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_in.\1",
r"^condition_embedder\.time_embedder\.linear_2\.(.*)$": r"condition_embedder.time_embedder.mlp.fc_out.\1",
r"^condition_embedder\.time_proj\.(.*)$": r"condition_embedder.time_modulation.linear.\1",
r"^blocks\.(\d+)\.attn1\.to_q\.(.*)$": r"blocks.\1.to_q.\2",
r"^blocks\.(\d+)\.attn1\.to_k\.(.*)$": r"blocks.\1.to_k.\2",
r"^blocks\.(\d+)\.attn1\.to_v\.(.*)$": r"blocks.\1.to_v.\2",
r"^blocks\.(\d+)\.attn1\.to_out\.0\.(.*)$": r"blocks.\1.to_out.\2",
r"^blocks\.(\d+)\.attn1\.norm_q\.(.*)$": r"blocks.\1.norm_q.\2",
r"^blocks\.(\d+)\.attn1\.norm_k\.(.*)$": r"blocks.\1.norm_k.\2",
r"^blocks\.(\d+)\.attn2\.to_out\.0\.(.*)$": r"blocks.\1.attn2.to_out.\2",
r"^blocks\.(\d+)\.ffn\.net\.0\.proj\.(.*)$": r"blocks.\1.ffn.fc_in.\2",
r"^blocks\.(\d+)\.ffn\.net\.2\.(.*)$": r"blocks.\1.ffn.fc_out.\2",
r"^blocks\.(\d+)\.norm2\.(.*)$": r"blocks.\1.self_attn_residual_norm.norm.\2",
})
patch_size: tuple[int, int, int] = (1, 2, 2)
in_channels: int = 48
out_channels: int = 48
num_attention_heads: int = 40
attention_head_dim: int = 128
ffn_dim: int = 13824
num_layers: int = 40
text_len: int = 512
image_dim: int = 0
use_text_crossattn: bool = True
use_memory: bool = True
sigma_theta: float = 0.8
camera_embed_in_channels: int = 1536
action_config: dict = field(
default_factory=lambda: {
"blocks": list(range(40)),
"enable_mouse": True,
"enable_keyboard": True,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 5120,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [8, 28, 28],
"patch_size": [1, 2, 2],
"qk_norm": True,
"qkv_bias": False,
"rope_dim_list": [8, 28, 28],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3,
})


@dataclass
class MatrixGame3WanVideoConfig(WanVideoConfig):
arch_config: MatrixGame3WanVideoArchConfig = field(default_factory=MatrixGame3WanVideoArchConfig)
prefix: str = "Wan"
_compile_conditions: list = field(default_factory=lambda: [_is_transformer_block])
3 changes: 2 additions & 1 deletion fastvideo/configs/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from fastvideo.configs.pipelines.hunyuangamecraft import HunyuanGameCraftPipelineConfig
from fastvideo.configs.pipelines.hyworld import HYWorldConfig
from fastvideo.configs.pipelines.ltx2 import LTX2T2VConfig
from fastvideo.configs.pipelines.matrixgame import MatrixGame2I2V480PConfig, MatrixGame3I2V720PConfig
from fastvideo.registry import get_pipeline_config_cls_from_name
from fastvideo.configs.pipelines.wan import (SelfForcingWanT2V480PConfig, WanI2V480PConfig, WanI2V720PConfig,
WanT2V480PConfig, WanT2V720PConfig)
Expand All @@ -14,5 +15,5 @@
"HunyuanConfig", "FastHunyuanConfig", "HunyuanGameCraftPipelineConfig", "PipelineConfig", "Hunyuan15T2V480PConfig",
"Hunyuan15T2V720PConfig", "WanT2V480PConfig", "WanI2V480PConfig", "WanT2V720PConfig", "WanI2V720PConfig",
"SelfForcingWanT2V480PConfig", "CosmosConfig", "Cosmos25Config", "LTX2T2VConfig", "HYWorldConfig",
"get_pipeline_config_cls_from_name"
"MatrixGame2I2V480PConfig", "MatrixGame3I2V720PConfig", "get_pipeline_config_cls_from_name"
]
36 changes: 36 additions & 0 deletions fastvideo/configs/pipelines/matrixgame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass, field

from fastvideo.configs.models import DiTConfig, EncoderConfig
from fastvideo.configs.models.dits.matrixgame import MatrixGame2WanVideoConfig, MatrixGame3WanVideoConfig
from fastvideo.configs.models.encoders import WAN2_1ControlCLIPVisionConfig
from fastvideo.configs.pipelines.wan import WanI2V480PConfig, WanT2V480PConfig


@dataclass
class MatrixGame2BaseI2V480PConfig(WanI2V480PConfig):
dit_config: DiTConfig = field(default_factory=MatrixGame2WanVideoConfig)
flow_shift: float | None = 5.0


@dataclass
class MatrixGame2I2V480PConfig(WanI2V480PConfig):
dit_config: DiTConfig = field(default_factory=MatrixGame2WanVideoConfig)
image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)
is_causal: bool = True
flow_shift: float | None = 5.0
dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
warp_denoising_step: bool = True
context_noise: int = 0
num_frames_per_block: int = 3


@dataclass
class MatrixGame3I2V720PConfig(WanT2V480PConfig):
dit_config: DiTConfig = field(default_factory=MatrixGame3WanVideoConfig)
flow_shift: float | None = 5.0
vae_precision: str = "fp32"

def __post_init__(self) -> None:
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True
25 changes: 0 additions & 25 deletions fastvideo/configs/pipelines/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from fastvideo.configs.models import DiTConfig, EncoderConfig, VAEConfig
from fastvideo.configs.models.dits import WanVideoConfig
from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
from fastvideo.configs.models.encoders import (BaseEncoderOutput, CLIPVisionConfig, T5Config,
WAN2_1ControlCLIPVisionConfig)
from fastvideo.configs.models.vaes import WanVAEConfig
Expand Down Expand Up @@ -177,27 +176,3 @@ class SelfForcingWan2_2_T2V480PConfig(Wan2_2_T2V_A14B_Config):
def __post_init__(self) -> None:
self.vae_config.load_encoder = True
self.vae_config.load_decoder = True


# =============================================
# ============= Matrix Game ===================
# =============================================
@dataclass
class MatrixGameBaseI2V480PConfig(WanI2V480PConfig):
dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)
flow_shift: float | None = 5.0


@dataclass
class MatrixGameI2V480PConfig(WanI2V480PConfig):
dit_config: DiTConfig = field(default_factory=MatrixGameWanVideoConfig)

image_encoder_config: EncoderConfig = field(default_factory=WAN2_1ControlCLIPVisionConfig)

is_causal: bool = True
flow_shift: float | None = 5.0
dmd_denoising_steps: list[int] | None = field(default_factory=lambda: [1000, 666, 333])
warp_denoising_step: bool = True
context_noise: int = 0
num_frames_per_block: int = 3
# sliding_window_num_frames: int = 15
28 changes: 28 additions & 0 deletions fastvideo/configs/sample/matrixgame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass

from fastvideo.configs.sample.base import SamplingParam


@dataclass
class MatrixGame2SamplingParam(SamplingParam):
height: int = 352
width: int = 640
num_frames: int = 57
fps: int = 25
guidance_scale: float = 1.0
num_inference_steps: int = 3
negative_prompt: str | None = None


@dataclass
class MatrixGame3SamplingParam(SamplingParam):
height: int = 720
width: int = 1280
num_frames: int = 57
fps: int = 25
guidance_scale: float = 1.0
num_inference_steps: int = 3
negative_prompt: str | None = None
num_iterations: int = 1
use_base_model: bool = False
11 changes: 0 additions & 11 deletions fastvideo/configs/sample/wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,3 @@ class SelfForcingWan2_2_T2V_A14B_480P_SamplingParam(Wan2_2_T2V_A14B_SamplingPara
height: int = 448
width: int = 832
fps: int = 16


@dataclass
class MatrixGame2_SamplingParam(SamplingParam):
height: int = 352
width: int = 640
num_frames: int = 57
fps: int = 25
guidance_scale: float = 1.0
num_inference_steps: int = 3
negative_prompt: str | None = None
6 changes: 3 additions & 3 deletions fastvideo/models/dits/matrixgame/causal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
)

from fastvideo.attention import LocalAttention
from fastvideo.configs.models.dits.matrixgame import MatrixGameWanVideoConfig
from fastvideo.configs.models.dits.matrixgame import MatrixGame2WanVideoConfig
from fastvideo.distributed.parallel_state import get_sp_world_size
from fastvideo.layers.layernorm import (
FP32LayerNorm,
Expand Down Expand Up @@ -584,7 +584,7 @@ def forward(
return hidden_states.to(orig_dtype)


_DEFAULT_MATRIXGAME_CONFIG = MatrixGameWanVideoConfig()
_DEFAULT_MATRIXGAME_CONFIG = MatrixGame2WanVideoConfig()


class CausalMatrixGameWanModel(BaseDiT):
Expand All @@ -605,7 +605,7 @@ class CausalMatrixGameWanModel(BaseDiT):

def __init__(
self,
config: MatrixGameWanVideoConfig,
config: MatrixGame2WanVideoConfig,
hf_config: dict[str, Any],
**kwargs,
) -> None:
Expand Down
Loading
Loading