73 changes: 72 additions & 1 deletion fastvideo/configs/models/dits/wanvideo.py
@@ -14,8 +14,65 @@ class WanVideoArchConfig(DiTArchConfig):

    param_names_mapping: dict = field(
        default_factory=lambda: {
            # Official Wan/LingBot checkpoint naming.
            r"^patch_embedding\.(.*)$":
            r"patch_embedding.proj.\1",
            r"^text_embedding\.0\.(.*)$":
            r"condition_embedder.text_embedder.fc_in.\1",
            r"^text_embedding\.2\.(.*)$":
            r"condition_embedder.text_embedder.fc_out.\1",
            r"^time_embedding\.0\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_in.\1",
            r"^time_embedding\.2\.(.*)$":
            r"condition_embedder.time_embedder.mlp.fc_out.\1",
            r"^time_projection\.1\.(.*)$":
            r"condition_embedder.time_modulation.linear.\1",
            r"^head\.modulation$":
            r"scale_shift_table",
            r"^head\.head\.(.*)$":
            r"proj_out.\1",
            r"^blocks\.(\d+)\.modulation$":
            r"blocks.\1.scale_shift_table",
            r"^blocks\.(\d+)\.self_attn\.q\.(.*)$":
            r"blocks.\1.to_q.\2",
            r"^blocks\.(\d+)\.self_attn\.k\.(.*)$":
            r"blocks.\1.to_k.\2",
            r"^blocks\.(\d+)\.self_attn\.v\.(.*)$":
            r"blocks.\1.to_v.\2",
            r"^blocks\.(\d+)\.self_attn\.o\.(.*)$":
            r"blocks.\1.to_out.\2",
            r"^blocks\.(\d+)\.self_attn\.norm_q\.(.*)$":
            r"blocks.\1.norm_q.\2",
            r"^blocks\.(\d+)\.self_attn\.norm_k\.(.*)$":
            r"blocks.\1.norm_k.\2",
            r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$":
            r"blocks.\1.attn2.to_q.\2",
            r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$":
            r"blocks.\1.attn2.to_k.\2",
            r"^blocks\.(\d+)\.cross_attn\.k_img\.(.*)$":
            r"blocks.\1.attn2.add_k_proj.\2",
            r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$":
            r"blocks.\1.attn2.to_v.\2",
            r"^blocks\.(\d+)\.cross_attn\.v_img\.(.*)$":
            r"blocks.\1.attn2.add_v_proj.\2",
            r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$":
            r"blocks.\1.attn2.to_out.\2",
            r"^blocks\.(\d+)\.cross_attn\.norm_q\.(.*)$":
            r"blocks.\1.attn2.norm_q.\2",
            r"^blocks\.(\d+)\.cross_attn\.norm_k\.(.*)$":
            r"blocks.\1.attn2.norm_k.\2",
            r"^blocks\.(\d+)\.cross_attn\.norm_q_img\.(.*)$":
            r"blocks.\1.attn2.norm_added_q.\2",
            r"^blocks\.(\d+)\.cross_attn\.norm_k_img\.(.*)$":
            r"blocks.\1.attn2.norm_added_k.\2",
            r"^blocks\.(\d+)\.ffn\.0\.(.*)$":
            r"blocks.\1.ffn.fc_in.\2",
            r"^blocks\.(\d+)\.ffn\.2\.(.*)$":
            r"blocks.\1.ffn.fc_out.\2",
            r"^blocks\.(\d+)\.norm3\.(.*)$":
            r"blocks.\1.self_attn_residual_norm.norm.\2",

            # Diffusers-style naming.
            r"^condition_embedder\.text_embedder\.linear_1\.(.*)$":
            r"condition_embedder.text_embedder.fc_in.\1",
            r"^condition_embedder\.text_embedder\.linear_2\.(.*)$":
@@ -74,11 +131,16 @@ class WanVideoArchConfig(DiTArchConfig):
        })

    patch_size: tuple[int, int, int] = (1, 2, 2)
    text_len = 512
    dim: int | None = None
    text_len: int = 512
    num_heads: int | None = None
    num_attention_heads: int = 40
    attention_head_dim: int = 128
    in_dim: int | None = None
    in_channels: int = 16
    out_dim: int | None = None
    out_channels: int = 16
    model_type: str | None = None
    text_dim: int = 4096
    freq_dim: int = 256
    ffn_dim: int = 13824
@@ -102,6 +164,15 @@ class WanVideoArchConfig(DiTArchConfig):
    sliding_window_num_frames: int = 21

    def __post_init__(self):
        if self.num_heads is not None:
            self.num_attention_heads = self.num_heads
        if self.in_dim is not None:
            self.in_channels = self.in_dim
        if self.out_dim is not None:
            self.out_channels = self.out_dim
        if self.dim is not None and self.num_attention_heads > 0:
            self.attention_head_dim = self.dim // self.num_attention_heads
            self.hidden_size = self.dim
        super().__post_init__()
        self.out_channels = self.out_channels or self.in_channels
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
Contributor comment on lines 166 to 178 (severity: high):
In the __post_init__ method, self.hidden_size is set on line 175 based on self.dim, but it's then unconditionally overwritten on line 178. This can lead to an incorrect hidden_size if self.dim is not a multiple of self.num_attention_heads due to integer division. For example, if dim was 5001 and num_attention_heads was 40, attention_head_dim would be 125, and hidden_size would be incorrectly set to 40 * 125 = 5000 instead of the intended 5001. The logic should be restructured to correctly prioritize self.dim for hidden_size when it's available.

Suggested change
def __post_init__(self):
    if self.num_heads is not None:
        self.num_attention_heads = self.num_heads
    if self.in_dim is not None:
        self.in_channels = self.in_dim
    if self.out_dim is not None:
        self.out_channels = self.out_dim
    if self.dim is not None and self.num_attention_heads > 0:
        self.attention_head_dim = self.dim // self.num_attention_heads
        self.hidden_size = self.dim
    super().__post_init__()
    self.out_channels = self.out_channels or self.in_channels
    self.hidden_size = self.num_attention_heads * self.attention_head_dim

def __post_init__(self):
    if self.num_heads is not None:
        self.num_attention_heads = self.num_heads
    if self.in_dim is not None:
        self.in_channels = self.in_dim
    if self.out_dim is not None:
        self.out_channels = self.out_dim
    super().__post_init__()
    if self.dim is not None and self.num_attention_heads > 0:
        self.attention_head_dim = self.dim // self.num_attention_heads
        self.hidden_size = self.dim
    else:
        self.hidden_size = self.num_attention_heads * self.attention_head_dim
    self.out_channels = self.out_channels or self.in_channels

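Note for reviewers: each entry in param_names_mapping is an anchored regex paired with its replacement, applied via re.sub when checkpoint keys are renamed. A minimal standalone sketch of the translation, using two rules copied from the mapping above (the remap helper is illustrative, not the loader's actual code path):

import re

mapping = {
    r"^blocks\.(\d+)\.self_attn\.q\.(.*)$": r"blocks.\1.to_q.\2",
    r"^time_projection\.1\.(.*)$": r"condition_embedder.time_modulation.linear.\1",
}

def remap(name: str) -> str:
    # Apply the first matching rule; leave unmatched names unchanged.
    for pattern, repl in mapping.items():
        new_name, n = re.subn(pattern, repl, name)
        if n:
            return new_name
    return name

assert remap("blocks.3.self_attn.q.weight") == "blocks.3.to_q.weight"
assert remap("time_projection.1.bias") == "condition_embedder.time_modulation.linear.bias"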
66 changes: 57 additions & 9 deletions fastvideo/models/loader/component_loader.py
@@ -28,6 +28,7 @@
from fastvideo.models.loader.fsdp_load import maybe_load_fsdp_model, shard_model
from fastvideo.models.loader.utils import set_default_torch_dtype
from fastvideo.models.loader.weight_utils import (
    extract_tensor_state_dict,
    filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference,
    pt_weights_iterator,
@@ -152,7 +153,7 @@ def _prepare_weights(

        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        allow_patterns = ["*.safetensors", "*.bin"]
        allow_patterns = ["*.safetensors", "*.bin", "*.pth"]

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]
@@ -244,11 +245,13 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
        # model_override_args=None,
        # )
        model_config = get_diffusers_config(model=model_path)
        model_config.pop("_class_name", None)
        model_config.pop("_name_or_path", None)
        model_config.pop("transformers_version", None)
        model_config.pop("model_type", None)
        model_config.pop("tokenizer_class", None)
        model_config.pop("torch_dtype", None)
        model_config.pop("architectures", None)
        repo_root = os.path.dirname(model_path)
        index_path = os.path.join(repo_root, "model_index.json")
        gemma_path = ""
@@ -625,16 +628,44 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
        vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
        vae = vae_cls(vae_config).to(target_device)

        # Find all safetensors files
        # Load weights from safetensors first, then fall back to pt/pth/bin.
        safetensors_list = glob.glob(
            os.path.join(str(model_path), "*.safetensors"))
        if not safetensors_list:
            raise ValueError(f"No safetensors files found in {model_path}")
        # Common case: a single `.safetensors` checkpoint file.
        # Some models may be sharded into multiple files; in that case we merge.
        loaded = {}
        for sf_file in safetensors_list:
            loaded.update(safetensors_load_file(sf_file))
        if safetensors_list:
            for sf_file in safetensors_list:
                loaded.update(safetensors_load_file(sf_file))
        else:
            pt_files: list[str] = []
            for pattern in ("*.pt", "*.pth", "*.bin"):
                pt_files = glob.glob(os.path.join(str(model_path), pattern))
                if pt_files:
                    break
            if not pt_files:
                raise ValueError(
                    f"No VAE weight files found in {model_path}. "
                    "Expected *.safetensors, *.pt, *.pth, or *.bin."
                )
            for pt_file in pt_files:
                try:
                    state = torch.load(pt_file, map_location="cpu", weights_only=True)
                except Exception:
                    state = torch.load(pt_file, map_location="cpu", weights_only=False)
                loaded.update(extract_tensor_state_dict(state))

        # Some legacy checkpoints store VAE weights with an extra top-level prefix.
        # Strip known prefixes only when there is no direct key overlap.
        model_keys = set(vae.state_dict().keys())
        if model_keys and not (model_keys & set(loaded.keys())):
            for prefix in ("vae.", "module.", "model.", "generator.",
                           "generator_ema."):
                stripped = {
                    (k[len(prefix):] if k.startswith(prefix) else k): v
                    for k, v in loaded.items()
                }
                if model_keys & set(stripped.keys()):
                    loaded = stripped
                    break

        # LTX-2 CausalVideoAutoencoder needs per_channel_statistics remapping
        if class_name == "CausalVideoAutoencoder" and "vae" in config:
@@ -751,6 +782,22 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):

        # Config from Diffusers supersedes fastvideo's model config
        dit_config = deepcopy(fastvideo_args.pipeline_config.dit_config)
        if cls_name == "WanModel":
            valid_fields = {
                field.name
                for field in dataclasses.fields(dit_config.arch_config)
            }
            unknown_fields = sorted(set(config.keys()) - valid_fields)
            if unknown_fields:
                logger.warning(
                    "Ignoring %d unsupported WanModel config fields: %s",
                    len(unknown_fields),
                    ", ".join(unknown_fields),
                )
            config = {
                key: value
                for key, value in config.items()
                if key in valid_fields
            }
        dit_config.update_model_arch(config)

        model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
@@ -810,6 +857,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
        strict_load = not (
            cls_name.startswith("Cosmos25")
            or cls_name == "Cosmos25Transformer3DModel"
            or cls_name == "WanModel"
            or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25"
        )
        model = maybe_load_fsdp_model(
@@ -1005,4 +1053,4 @@ def load_module(
    )

    # Load the module
    return loader.load(component_model_path, fastvideo_args)
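Note: the WanModel branch in the DiT loader above keeps only config keys that are declared fields of the arch-config dataclass before calling update_model_arch. A self-contained sketch of that filtering, with a hypothetical ArchConfig standing in for dit_config.arch_config:

import dataclasses

@dataclasses.dataclass
class ArchConfig:  # hypothetical stand-in for dit_config.arch_config
    dim: int | None = None
    num_heads: int | None = None

config = {"dim": 5120, "num_heads": 40, "window_size": (-1, -1)}
valid_fields = {f.name for f in dataclasses.fields(ArchConfig)}
unknown_fields = sorted(set(config) - valid_fields)  # ["window_size"]: warned about, then dropped
config = {k: v for k, v in config.items() if k in valid_fields}
assert config == {"dim": 5120, "num_heads": 40}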
50 changes: 49 additions & 1 deletion fastvideo/models/loader/weight_utils.py
@@ -151,11 +151,59 @@ def pt_weights_iterator(
        disable=not enable_tqdm,
        bar_format=_BAR_FORMAT,
    ):
        state = torch.load(bin_file, map_location=device, weights_only=True)
        try:
            state = torch.load(bin_file, map_location=device, weights_only=True)
        except Exception:
            # Some legacy checkpoints contain objects unsupported by
            # weights_only=True. Fall back for trusted local checkpoints.
            state = torch.load(bin_file, map_location=device, weights_only=False)
        state = extract_tensor_state_dict(state)
        yield from state.items()
        del state

def extract_tensor_state_dict(state: object) -> dict[str, torch.Tensor]:
    """Extract a plain tensor state_dict from common checkpoint wrappers."""

    def _is_tensor_dict(obj: object) -> bool:
        return isinstance(obj, dict) and all(
            isinstance(k, str) and isinstance(v, torch.Tensor)
            for k, v in obj.items())

    def _dfs_find_tensor_dict(
            obj: object, depth: int = 3) -> dict[str, torch.Tensor] | None:
        if depth < 0:
            return None
        if _is_tensor_dict(obj):
            return obj  # type: ignore[return-value]
        if not isinstance(obj, dict):
            return None
        for key in (
                "state_dict",
                "model",
                "module",
                "model_state_dict",
                "ema",
                "generator_ema",
                "weights",
        ):
            if key in obj:
                found = _dfs_find_tensor_dict(obj[key], depth - 1)
                if found is not None:
                    return found
        for value in obj.values():
            found = _dfs_find_tensor_dict(value, depth - 1)
            if found is not None:
                return found
        return None

    found = _dfs_find_tensor_dict(state)
    if found is None:
        raise ValueError(
            "Failed to find a tensor state_dict in checkpoint. "
            "Expected a dict[str, torch.Tensor] or a common wrapper around it.")
    return found


def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
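A quick sanity check of the unwrapping behavior, on a hypothetical legacy-style checkpoint (the import path follows the diff above):

import torch

from fastvideo.models.loader.weight_utils import extract_tensor_state_dict

# Tensors nested under a common wrapper key, plus unrelated metadata.
ckpt = {"epoch": 12, "generator_ema": {"conv.weight": torch.zeros(3, 3)}}
state_dict = extract_tensor_state_dict(ckpt)
assert set(state_dict) == {"conv.weight"}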
1 change: 1 addition & 0 deletions fastvideo/models/registry.py
@@ -42,6 +42,7 @@
_IMAGE_TO_VIDEO_DIT_MODELS = {
    # "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoDiT"),
    "WanTransformer3DModel": ("dits", "wanvideo", "WanTransformer3DModel"),
    "WanModel": ("dits", "wanvideo", "WanTransformer3DModel"),  # alias for official Wan checkpoints
    "CausalWanTransformer3DModel": ("dits", "causal_wanvideo", "CausalWanTransformer3DModel"),
    "MatrixGameWanModel": ("dits", "matrixgame", "MatrixGameWanModel"),
    "CausalMatrixGameWanModel": ("dits", "matrixgame", "CausalMatrixGameWanModel"),
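With the alias registered, "WanModel" and "WanTransformer3DModel" should resolve to the same implementation. A hedged sketch, assuming ModelRegistry.resolve_model_cls accepts a class-name string and returns a tuple whose first element is the class, as used in component_loader.py above:

from fastvideo.models.registry import ModelRegistry

wan_cls, _ = ModelRegistry.resolve_model_cls("WanModel")
canonical_cls, _ = ModelRegistry.resolve_model_cls("WanTransformer3DModel")
assert wan_cls is canonical_cls  # both map to dits/wanvideo.WanTransformer3DModel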
76 changes: 76 additions & 0 deletions fastvideo/pipelines/basic/wan/wan_cam_i2v_pipeline.py
@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
"""
LingBot/Wan camera-control image-to-video pipeline.

This pipeline targets Wan-style I2V checkpoints that provide two transformer
experts (high-noise + low-noise) but do not require CLIP image embeddings.
"""

from fastvideo.fastvideo_args import FastVideoArgs
from fastvideo.logger import init_logger
from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import (
    FlowUniPCMultistepScheduler)
from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline
from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage,
                                        DenoisingStage, ImageVAEEncodingStage,
                                        InputValidationStage,
                                        LatentPreparationStage,
                                        TextEncodingStage,
                                        TimestepPreparationStage)

logger = init_logger(__name__)


class WanCamImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase):

    _required_config_modules = [
        "text_encoder",
        "tokenizer",
        "vae",
        "transformer",
        "transformer_2",
    ]

    def initialize_pipeline(self, fastvideo_args: FastVideoArgs):
        self.modules["scheduler"] = FlowUniPCMultistepScheduler(
            shift=fastvideo_args.pipeline_config.flow_shift)

    def create_pipeline_stages(self, fastvideo_args: FastVideoArgs):
        self.add_stage(stage_name="input_validation_stage",
                       stage=InputValidationStage())

        self.add_stage(stage_name="prompt_encoding_stage",
                       stage=TextEncodingStage(
                           text_encoders=[self.get_module("text_encoder")],
                           tokenizers=[self.get_module("tokenizer")],
                       ))

        self.add_stage(stage_name="conditioning_stage",
                       stage=ConditioningStage())

        self.add_stage(stage_name="timestep_preparation_stage",
                       stage=TimestepPreparationStage(
                           scheduler=self.get_module("scheduler")))

        self.add_stage(stage_name="latent_preparation_stage",
                       stage=LatentPreparationStage(
                           scheduler=self.get_module("scheduler"),
                           transformer=self.get_module("transformer")))

        self.add_stage(stage_name="image_latent_preparation_stage",
                       stage=ImageVAEEncodingStage(vae=self.get_module("vae")))

        self.add_stage(stage_name="denoising_stage",
                       stage=DenoisingStage(
                           transformer=self.get_module("transformer"),
                           transformer_2=self.get_module("transformer_2", None),
                           scheduler=self.get_module("scheduler"),
                           vae=self.get_module("vae"),
                           pipeline=self))

        self.add_stage(stage_name="decoding_stage",
                       stage=DecodingStage(vae=self.get_module("vae"),
                                           pipeline=self))


EntryClass = WanCamImageToVideoPipeline
11 changes: 11 additions & 0 deletions fastvideo/registry.py
@@ -507,6 +507,17 @@ def _register_configs() -> None:
            "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
        ],
    )
    register_configs(
        sampling_param_cls=Wan2_2_I2V_A14B_SamplingParam,
        pipeline_config_cls=Wan2_2_I2V_A14B_Config,
        hf_model_paths=[
            "robbyant/lingbot-world-base-cam",
        ],
        model_detectors=[
            lambda path: "lingbot-world" in path.lower(),
            lambda path: "wancamimagetovideopipeline" in path.lower(),
        ],
    )
    register_configs(
        sampling_param_cls=SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam,
        pipeline_config_cls=SelfForcingWanT2V480PConfig,
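The new model_detectors entries are plain substring checks on the lowercased model path; a standalone check:

detectors = [
    lambda path: "lingbot-world" in path.lower(),
    lambda path: "wancamimagetovideopipeline" in path.lower(),
]

assert any(d("robbyant/lingbot-world-base-cam") for d in detectors)
assert not any(d("Wan-AI/Wan2.2-I2V-A14B-Diffusers") for d in detectors)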