diff --git a/fastvideo/configs/models/dits/wanvideo.py b/fastvideo/configs/models/dits/wanvideo.py
index 77bac8ebe4..71de2f84bb 100644
--- a/fastvideo/configs/models/dits/wanvideo.py
+++ b/fastvideo/configs/models/dits/wanvideo.py
@@ -14,8 +14,65 @@ class WanVideoArchConfig(DiTArchConfig):
 
     param_names_mapping: dict = field(
         default_factory=lambda: {
+            # Official Wan/LingBot checkpoint naming.
             r"^patch_embedding\.(.*)$": r"patch_embedding.proj.\1",
+            r"^text_embedding\.0\.(.*)$":
+            r"condition_embedder.text_embedder.fc_in.\1",
+            r"^text_embedding\.2\.(.*)$":
+            r"condition_embedder.text_embedder.fc_out.\1",
+            r"^time_embedding\.0\.(.*)$":
+            r"condition_embedder.time_embedder.mlp.fc_in.\1",
+            r"^time_embedding\.2\.(.*)$":
+            r"condition_embedder.time_embedder.mlp.fc_out.\1",
+            r"^time_projection\.1\.(.*)$":
+            r"condition_embedder.time_modulation.linear.\1",
+            r"^head\.modulation$":
+            r"scale_shift_table",
+            r"^head\.head\.(.*)$":
+            r"proj_out.\1",
+            r"^blocks\.(\d+)\.modulation$":
+            r"blocks.\1.scale_shift_table",
+            r"^blocks\.(\d+)\.self_attn\.q\.(.*)$":
+            r"blocks.\1.to_q.\2",
+            r"^blocks\.(\d+)\.self_attn\.k\.(.*)$":
+            r"blocks.\1.to_k.\2",
+            r"^blocks\.(\d+)\.self_attn\.v\.(.*)$":
+            r"blocks.\1.to_v.\2",
+            r"^blocks\.(\d+)\.self_attn\.o\.(.*)$":
+            r"blocks.\1.to_out.\2",
+            r"^blocks\.(\d+)\.self_attn\.norm_q\.(.*)$":
+            r"blocks.\1.norm_q.\2",
+            r"^blocks\.(\d+)\.self_attn\.norm_k\.(.*)$":
+            r"blocks.\1.norm_k.\2",
+            r"^blocks\.(\d+)\.cross_attn\.q\.(.*)$":
+            r"blocks.\1.attn2.to_q.\2",
+            r"^blocks\.(\d+)\.cross_attn\.k\.(.*)$":
+            r"blocks.\1.attn2.to_k.\2",
+            r"^blocks\.(\d+)\.cross_attn\.k_img\.(.*)$":
+            r"blocks.\1.attn2.add_k_proj.\2",
+            r"^blocks\.(\d+)\.cross_attn\.v\.(.*)$":
+            r"blocks.\1.attn2.to_v.\2",
+            r"^blocks\.(\d+)\.cross_attn\.v_img\.(.*)$":
+            r"blocks.\1.attn2.add_v_proj.\2",
+            r"^blocks\.(\d+)\.cross_attn\.o\.(.*)$":
+            r"blocks.\1.attn2.to_out.\2",
+            r"^blocks\.(\d+)\.cross_attn\.norm_q\.(.*)$":
+            r"blocks.\1.attn2.norm_q.\2",
+            r"^blocks\.(\d+)\.cross_attn\.norm_k\.(.*)$":
+            r"blocks.\1.attn2.norm_k.\2",
+            r"^blocks\.(\d+)\.cross_attn\.norm_q_img\.(.*)$":
+            r"blocks.\1.attn2.norm_added_q.\2",
+            r"^blocks\.(\d+)\.cross_attn\.norm_k_img\.(.*)$":
+            r"blocks.\1.attn2.norm_added_k.\2",
+            r"^blocks\.(\d+)\.ffn\.0\.(.*)$":
+            r"blocks.\1.ffn.fc_in.\2",
+            r"^blocks\.(\d+)\.ffn\.2\.(.*)$":
+            r"blocks.\1.ffn.fc_out.\2",
+            r"^blocks\.(\d+)\.norm3\.(.*)$":
+            r"blocks.\1.self_attn_residual_norm.norm.\2",
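+            # Illustrative example of the rules above: the official key
+            # "blocks.3.cross_attn.k_img.weight" is rewritten to
+            # "blocks.3.attn2.add_k_proj.weight".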
r"^condition_embedder\.text_embedder\.linear_1\.(.*)$": r"condition_embedder.text_embedder.fc_in.\1", r"^condition_embedder\.text_embedder\.linear_2\.(.*)$": @@ -74,11 +131,16 @@ class WanVideoArchConfig(DiTArchConfig): }) patch_size: tuple[int, int, int] = (1, 2, 2) - text_len = 512 + dim: int | None = None + text_len: int = 512 + num_heads: int | None = None num_attention_heads: int = 40 attention_head_dim: int = 128 + in_dim: int | None = None in_channels: int = 16 + out_dim: int | None = None out_channels: int = 16 + model_type: str | None = None text_dim: int = 4096 freq_dim: int = 256 ffn_dim: int = 13824 @@ -102,6 +164,15 @@ class WanVideoArchConfig(DiTArchConfig): sliding_window_num_frames: int = 21 def __post_init__(self): + if self.num_heads is not None: + self.num_attention_heads = self.num_heads + if self.in_dim is not None: + self.in_channels = self.in_dim + if self.out_dim is not None: + self.out_channels = self.out_dim + if self.dim is not None and self.num_attention_heads > 0: + self.attention_head_dim = self.dim // self.num_attention_heads + self.hidden_size = self.dim super().__post_init__() self.out_channels = self.out_channels or self.in_channels self.hidden_size = self.num_attention_heads * self.attention_head_dim diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py index 8437e98214..b31ccf188f 100644 --- a/fastvideo/models/loader/component_loader.py +++ b/fastvideo/models/loader/component_loader.py @@ -28,6 +28,7 @@ from fastvideo.models.loader.fsdp_load import maybe_load_fsdp_model, shard_model from fastvideo.models.loader.utils import set_default_torch_dtype from fastvideo.models.loader.weight_utils import ( + extract_tensor_state_dict, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, pt_weights_iterator, @@ -152,7 +153,7 @@ def _prepare_weights( use_safetensors = False index_file = SAFE_WEIGHTS_INDEX_NAME - allow_patterns = ["*.safetensors", "*.bin"] + allow_patterns = ["*.safetensors", "*.bin", "*.pth"] if fall_back_to_pt: allow_patterns += ["*.pt"] @@ -244,11 +245,13 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): # model_override_args=None, # ) model_config = get_diffusers_config(model=model_path) + model_config.pop("_class_name", None) model_config.pop("_name_or_path", None) model_config.pop("transformers_version", None) model_config.pop("model_type", None) model_config.pop("tokenizer_class", None) model_config.pop("torch_dtype", None) + model_config.pop("architectures", None) repo_root = os.path.dirname(model_path) index_path = os.path.join(repo_root, "model_index.json") gemma_path = "" @@ -625,16 +628,44 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs): vae_cls, _ = ModelRegistry.resolve_model_cls(class_name) vae = vae_cls(vae_config).to(target_device) - # Find all safetensors files + # Load weights from safetensors first, then fallback to pt/pth/bin. safetensors_list = glob.glob( os.path.join(str(model_path), "*.safetensors")) - if not safetensors_list: - raise ValueError(f"No safetensors files found in {model_path}") - # Common case: a single `.safetensors` checkpoint file. - # Some models may be sharded into multiple files; in that case we merge. 
         super().__post_init__()
         self.out_channels = self.out_channels or self.in_channels
         self.hidden_size = self.num_attention_heads * self.attention_head_dim
diff --git a/fastvideo/models/loader/component_loader.py b/fastvideo/models/loader/component_loader.py
index 8437e98214..b31ccf188f 100644
--- a/fastvideo/models/loader/component_loader.py
+++ b/fastvideo/models/loader/component_loader.py
@@ -28,6 +28,7 @@ from fastvideo.models.loader.fsdp_load import maybe_load_fsdp_model, shard_model
 from fastvideo.models.loader.utils import set_default_torch_dtype
 from fastvideo.models.loader.weight_utils import (
+    extract_tensor_state_dict,
     filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference,
     pt_weights_iterator,
@@ -152,7 +153,7 @@ def _prepare_weights(
         use_safetensors = False
         index_file = SAFE_WEIGHTS_INDEX_NAME
 
-        allow_patterns = ["*.safetensors", "*.bin"]
+        allow_patterns = ["*.safetensors", "*.bin", "*.pth"]
 
         if fall_back_to_pt:
             allow_patterns += ["*.pt"]
@@ -244,11 +245,13 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
         #     model_override_args=None,
         # )
         model_config = get_diffusers_config(model=model_path)
+        model_config.pop("_class_name", None)
         model_config.pop("_name_or_path", None)
         model_config.pop("transformers_version", None)
         model_config.pop("model_type", None)
         model_config.pop("tokenizer_class", None)
         model_config.pop("torch_dtype", None)
+        model_config.pop("architectures", None)
         repo_root = os.path.dirname(model_path)
         index_path = os.path.join(repo_root, "model_index.json")
         gemma_path = ""
@@ -625,16 +628,44 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
         vae_cls, _ = ModelRegistry.resolve_model_cls(class_name)
         vae = vae_cls(vae_config).to(target_device)
 
-        # Find all safetensors files
+        # Load weights from safetensors first, then fall back to pt/pth/bin.
         safetensors_list = glob.glob(
             os.path.join(str(model_path), "*.safetensors"))
-        if not safetensors_list:
-            raise ValueError(f"No safetensors files found in {model_path}")
-        # Common case: a single `.safetensors` checkpoint file.
-        # Some models may be sharded into multiple files; in that case we merge.
         loaded = {}
-        for sf_file in safetensors_list:
-            loaded.update(safetensors_load_file(sf_file))
+        if safetensors_list:
+            for sf_file in safetensors_list:
+                loaded.update(safetensors_load_file(sf_file))
+        else:
+            pt_files: list[str] = []
+            for pattern in ("*.pt", "*.pth", "*.bin"):
+                pt_files = glob.glob(os.path.join(str(model_path), pattern))
+                if pt_files:
+                    break
+            if not pt_files:
+                raise ValueError(
+                    f"No VAE weight files found in {model_path}. "
+                    "Expected *.safetensors, *.pt, *.pth, or *.bin."
+                )
+            for pt_file in pt_files:
+                try:
+                    state = torch.load(pt_file, map_location="cpu", weights_only=True)
+                except Exception:
+                    state = torch.load(pt_file, map_location="cpu", weights_only=False)
+                loaded.update(extract_tensor_state_dict(state))
+
+        # Some legacy checkpoints store VAE weights with an extra top-level
+        # prefix. Strip known prefixes only when there is no direct key overlap.
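+        # e.g. a legacy key such as "vae.encoder.conv1.weight" (name
+        # illustrative) only matches the model's "encoder.conv1.weight"
+        # once the "vae." prefix is removed.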
+        model_keys = set(vae.state_dict().keys())
+        if model_keys and not (model_keys & set(loaded.keys())):
+            for prefix in ("vae.", "module.", "model.", "generator.",
+                           "generator_ema."):
+                stripped = {
+                    (k[len(prefix):] if k.startswith(prefix) else k): v
+                    for k, v in loaded.items()
+                }
+                if model_keys & set(stripped.keys()):
+                    loaded = stripped
+                    break
 
         # LTX-2 CausalVideoAutoencoder needs per_channel_statistics remapping
         if class_name == "CausalVideoAutoencoder" and "vae" in config:
@@ -751,6 +782,22 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
 
         # Config from Diffusers supersedes fastvideo's model config
         dit_config = deepcopy(fastvideo_args.pipeline_config.dit_config)
+        if cls_name == "WanModel":
+            valid_fields = {
+                field.name for field in dataclasses.fields(dit_config.arch_config)
+            }
+            unknown_fields = sorted(set(config.keys()) - valid_fields)
+            if unknown_fields:
+                logger.warning(
+                    "Ignoring %d unsupported WanModel config fields: %s",
+                    len(unknown_fields),
+                    ", ".join(unknown_fields),
+                )
+            config = {
+                key: value
+                for key, value in config.items()
+                if key in valid_fields
+            }
         dit_config.update_model_arch(config)
 
         model_cls, _ = ModelRegistry.resolve_model_cls(cls_name)
@@ -810,6 +857,7 @@ def load(self, model_path: str, fastvideo_args: FastVideoArgs):
         strict_load = not (
             cls_name.startswith("Cosmos25")
             or cls_name == "Cosmos25Transformer3DModel"
+            or cls_name == "WanModel"
            or getattr(fastvideo_args.pipeline_config, "prefix", "") == "Cosmos25"
         )
         model = maybe_load_fsdp_model(
@@ -1005,4 +1053,4 @@ def load_module(
     )
 
     # Load the module
-    return loader.load(component_model_path, fastvideo_args)
\ No newline at end of file
+    return loader.load(component_model_path, fastvideo_args)
diff --git a/fastvideo/models/loader/weight_utils.py b/fastvideo/models/loader/weight_utils.py
index 3e6fa91ba8..5598d75920 100644
--- a/fastvideo/models/loader/weight_utils.py
+++ b/fastvideo/models/loader/weight_utils.py
@@ -151,11 +151,59 @@ def pt_weights_iterator(
             disable=not enable_tqdm,
             bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location=device, weights_only=True)
+        try:
+            state = torch.load(bin_file, map_location=device, weights_only=True)
+        except Exception:
+            # Some legacy checkpoints contain objects unsupported by
+            # weights_only=True. Fall back for trusted local checkpoints.
+            state = torch.load(bin_file, map_location=device, weights_only=False)
+        state = extract_tensor_state_dict(state)
         yield from state.items()
         del state
 
 
+def extract_tensor_state_dict(state: object) -> dict[str, torch.Tensor]:
+    """Extract a plain tensor state_dict from common checkpoint wrappers."""
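+    # Handles layouts such as {"state_dict": {...}} or nested wrappers like
+    # {"ema": {"module": {...}}}, returning the first dict (searched a few
+    # levels deep) whose values are all tensors.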
+""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_unipc_multistep import ( + FlowUniPCMultistepScheduler) +from fastvideo.pipelines import ComposedPipelineBase, LoRAPipeline +from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, + DenoisingStage, ImageVAEEncodingStage, + InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage) + +logger = init_logger(__name__) + + +class WanCamImageToVideoPipeline(LoRAPipeline, ComposedPipelineBase): + + _required_config_modules = [ + "text_encoder", + "tokenizer", + "vae", + "transformer", + "transformer_2", + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + self.modules["scheduler"] = FlowUniPCMultistepScheduler( + shift=fastvideo_args.pipeline_config.flow_shift) + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + self.add_stage(stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"))) + + self.add_stage(stage_name="image_latent_preparation_stage", + stage=ImageVAEEncodingStage(vae=self.get_module("vae"))) + + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + transformer_2=self.get_module("transformer_2", None), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"), + pipeline=self)) + + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"), + pipeline=self)) + + +EntryClass = WanCamImageToVideoPipeline diff --git a/fastvideo/registry.py b/fastvideo/registry.py index f6b7dc21c3..c31f2093f1 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -507,6 +507,17 @@ def _register_configs() -> None: "Wan-AI/Wan2.2-I2V-A14B-Diffusers", ], ) + register_configs( + sampling_param_cls=Wan2_2_I2V_A14B_SamplingParam, + pipeline_config_cls=Wan2_2_I2V_A14B_Config, + hf_model_paths=[ + "robbyant/lingbot-world-base-cam", + ], + model_detectors=[ + lambda path: "lingbot-world" in path.lower(), + lambda path: "wancamimagetovideopipeline" in path.lower(), + ], + ) register_configs( sampling_param_cls=SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam, pipeline_config_cls=SelfForcingWanT2V480PConfig, diff --git a/fastvideo/worker/multiproc_executor.py b/fastvideo/worker/multiproc_executor.py index cc09c58982..c3083d9ac4 100644 --- a/fastvideo/worker/multiproc_executor.py +++ b/fastvideo/worker/multiproc_executor.py @@ -650,8 +650,10 @@ def worker_busy_loop(self) -> None: logging_info = None if envs.FASTVIDEO_STAGE_LOGGING: logging_info = output_batch.logging_info + # result tensor shared by CUDA IPC to avoid serialization overhead + result = output_batch.output self.pipe.send({ - "output_batch": output_batch.output.cpu(), + "output_batch": result, "logging_info": logging_info, "extra": output_batch.extra, }) @@ -739,4 +741,4 @@ def set_multiproc_executor_envs() -> None: 
+        self.add_stage(stage_name="denoising_stage",
+                       stage=DenoisingStage(
+                           transformer=self.get_module("transformer"),
+                           transformer_2=self.get_module("transformer_2", None),
+                           scheduler=self.get_module("scheduler"),
+                           vae=self.get_module("vae"),
+                           pipeline=self))
+
+        self.add_stage(stage_name="decoding_stage",
+                       stage=DecodingStage(vae=self.get_module("vae"),
+                                           pipeline=self))
+
+
+EntryClass = WanCamImageToVideoPipeline
diff --git a/fastvideo/registry.py b/fastvideo/registry.py
index f6b7dc21c3..c31f2093f1 100644
--- a/fastvideo/registry.py
+++ b/fastvideo/registry.py
@@ -507,6 +507,17 @@ def _register_configs() -> None:
             "Wan-AI/Wan2.2-I2V-A14B-Diffusers",
         ],
     )
+    register_configs(
+        sampling_param_cls=Wan2_2_I2V_A14B_SamplingParam,
+        pipeline_config_cls=Wan2_2_I2V_A14B_Config,
+        hf_model_paths=[
+            "robbyant/lingbot-world-base-cam",
+        ],
+        model_detectors=[
+            lambda path: "lingbot-world" in path.lower(),
+            lambda path: "wancamimagetovideopipeline" in path.lower(),
+        ],
+    )
     register_configs(
         sampling_param_cls=SelfForcingWan2_1_T2V_1_3B_480P_SamplingParam,
         pipeline_config_cls=SelfForcingWanT2V480PConfig,
diff --git a/fastvideo/worker/multiproc_executor.py b/fastvideo/worker/multiproc_executor.py
index cc09c58982..c3083d9ac4 100644
--- a/fastvideo/worker/multiproc_executor.py
+++ b/fastvideo/worker/multiproc_executor.py
@@ -650,8 +650,10 @@ def worker_busy_loop(self) -> None:
                 logging_info = None
                 if envs.FASTVIDEO_STAGE_LOGGING:
                     logging_info = output_batch.logging_info
+                # Share the result tensor over CUDA IPC instead of copying it to the CPU.
+                result = output_batch.output
                 self.pipe.send({
-                    "output_batch": output_batch.output.cpu(),
+                    "output_batch": result,
                     "logging_info": logging_info,
                     "extra": output_batch.extra,
                 })
@@ -739,4 +741,4 @@ def set_multiproc_executor_envs() -> None:
     in a multiprocessing environment. This should be called by the parent
     process before worker processes are created"""
 
-    force_spawn()
+    force_spawn()
\ No newline at end of file
diff --git a/scripts/checkpoint_conversion/lingbot_world_to_fastvideo.py b/scripts/checkpoint_conversion/lingbot_world_to_fastvideo.py
new file mode 100644
index 0000000000..5278de1a9b
--- /dev/null
+++ b/scripts/checkpoint_conversion/lingbot_world_to_fastvideo.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python3
+"""Convert LingBot-World checkpoints into a FastVideo-compatible repo layout."""
+
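+# Example invocation (paths are illustrative):
+#   python scripts/checkpoint_conversion/lingbot_world_to_fastvideo.py \
+#       --source-dir /path/to/lingbot-world-base-cam \
+#       --output-dir /path/to/lingbot-world-base-cam-fastvideo \
+#       --shared-components-dir /path/to/Wan2.2-I2V-A14B-Diffusers
+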
+from __future__ import annotations
+
+import argparse
+import json
+import shutil
+from pathlib import Path
+
+UMT5_XXL_CONFIG = {
+    "vocab_size": 250112,
+    "d_model": 4096,
+    "d_kv": 64,
+    "d_ff": 10240,
+    "num_layers": 24,
+    "num_heads": 64,
+    "relative_attention_num_buckets": 32,
+    "relative_attention_max_distance": 128,
+    "dropout_rate": 0.1,
+    "layer_norm_epsilon": 1e-6,
+    "feed_forward_proj": "gated-gelu",
+    "is_encoder_decoder": True,
+    "use_cache": True,
+    "pad_token_id": 0,
+    "eos_token_id": 1,
+    "text_len": 512,
+}
+
+WAN_VAE_CONFIG = {
+    "_class_name": "AutoencoderKLWan",
+    "_diffusers_version": "0.33.0",
+}
+
+
+def _safe_copytree(src: Path, dst: Path) -> None:
+    if dst.exists():
+        shutil.rmtree(dst)
+    shutil.copytree(src, dst)
+
+
+def _ensure_json(path: Path, content: dict) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with path.open("w", encoding="utf-8") as f:
+        json.dump(content, f, indent=2)
+
+
+def _normalize_transformer_config(config_path: Path) -> None:
+    with config_path.open(encoding="utf-8") as f:
+        config = json.load(f)
+    config["_class_name"] = config.get("_class_name", "WanModel")
+    config["_diffusers_version"] = config.get("_diffusers_version", "0.33.0")
+    with config_path.open("w", encoding="utf-8") as f:
+        json.dump(config, f, indent=2)
+
+
+def _maybe_copy_shared_component(component: str, shared_root: Path | None,
+                                 output_root: Path) -> bool:
+    if shared_root is None:
+        return False
+    src = shared_root / component
+    if not src.exists():
+        return False
+    _safe_copytree(src, output_root / component)
+    return True
+
+
+def convert(
+    source_dir: Path,
+    output_dir: Path,
+    shared_components_dir: Path | None,
+    pipeline_class_name: str,
+    boundary_ratio: float,
+) -> None:
+    high_noise = source_dir / "high_noise_model"
+    low_noise = source_dir / "low_noise_model"
+
+    if not high_noise.exists() or not low_noise.exists():
+        raise FileNotFoundError(
+            "source_dir must contain both high_noise_model/ and low_noise_model/"
+        )
+
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    _safe_copytree(high_noise, output_dir / "transformer")
+    _safe_copytree(low_noise, output_dir / "transformer_2")
+    _normalize_transformer_config(output_dir / "transformer" / "config.json")
+    _normalize_transformer_config(output_dir / "transformer_2" / "config.json")
+
+    copied_tokenizer = _maybe_copy_shared_component("tokenizer",
+                                                    shared_components_dir,
+                                                    output_dir)
+    if not copied_tokenizer:
+        tokenizer_src = source_dir / "google" / "umt5-xxl"
+        if not tokenizer_src.exists():
+            raise FileNotFoundError(
+                "Missing tokenizer directory. Provide --shared-components-dir "
+                "or ensure source_dir/google/umt5-xxl exists.")
+        _safe_copytree(tokenizer_src, output_dir / "tokenizer")
+
+    copied_text_encoder = _maybe_copy_shared_component("text_encoder",
+                                                       shared_components_dir,
+                                                       output_dir)
+    if not copied_text_encoder:
+        text_encoder_dir = output_dir / "text_encoder"
+        text_encoder_dir.mkdir(parents=True, exist_ok=True)
+        ckpt = source_dir / "models_t5_umt5-xxl-enc-bf16.pth"
+        if not ckpt.exists():
+            raise FileNotFoundError(
+                "Missing text encoder checkpoint models_t5_umt5-xxl-enc-bf16.pth. "
+                "Provide --shared-components-dir with a prepared text_encoder/ "
+                "or place the checkpoint under source_dir.")
+        shutil.copy2(ckpt, text_encoder_dir / ckpt.name)
+        _ensure_json(text_encoder_dir / "config.json", UMT5_XXL_CONFIG)
+
+    copied_vae = _maybe_copy_shared_component("vae", shared_components_dir,
+                                              output_dir)
+    if not copied_vae:
+        vae_dir = output_dir / "vae"
+        vae_dir.mkdir(parents=True, exist_ok=True)
+        vae_ckpt = source_dir / "Wan2.1_VAE.pth"
+        if not vae_ckpt.exists():
+            raise FileNotFoundError(
+                "Missing VAE checkpoint Wan2.1_VAE.pth. "
+                "Provide --shared-components-dir with a prepared vae/ "
+                "or place the checkpoint under source_dir.")
+        shutil.copy2(vae_ckpt, vae_dir / vae_ckpt.name)
+        _ensure_json(vae_dir / "config.json", WAN_VAE_CONFIG)
+
+    model_index = {
+        "_class_name": pipeline_class_name,
+        "_diffusers_version": "0.33.0",
+        "workload_type": "i2v",
+        "boundary_ratio": boundary_ratio,
+        "transformer": ["diffusers", "WanModel"],
+        "transformer_2": ["diffusers", "WanModel"],
+        "vae": ["diffusers", "AutoencoderKLWan"],
+        "text_encoder": ["transformers", "UMT5EncoderModel"],
+        "tokenizer": ["transformers", "AutoTokenizer"],
+    }
+    _ensure_json(output_dir / "model_index.json", model_index)
+
+    print("Prepared FastVideo model repo at:", output_dir)
+    print("Pipeline:", pipeline_class_name)
+    print("Boundary ratio:", boundary_ratio)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Convert LingBot-World layout to FastVideo-compatible layout."
+    )
+    parser.add_argument(
+        "--source-dir",
+        type=Path,
+        required=True,
+        help="Path to the downloaded lingbot-world-base-cam directory.",
+    )
+    parser.add_argument(
+        "--output-dir",
+        type=Path,
+        required=True,
+        help="Output directory for the FastVideo-compatible model layout.",
+    )
+    parser.add_argument(
+        "--shared-components-dir",
+        type=Path,
+        default=None,
+        help=(
+            "Optional directory containing prepared text_encoder/, tokenizer/, "
+            "and vae/ folders (for example from a Wan2.2 Diffusers repo)."
+        ),
+    )
+    parser.add_argument(
+        "--pipeline-class-name",
+        type=str,
+        default="WanCamImageToVideoPipeline",
+        help="Pipeline class name written into model_index.json.",
+    )
+    parser.add_argument(
+        "--boundary-ratio",
+        type=float,
+        default=0.9,
+        help=("boundary_ratio written into model_index.json for "
+              "dual-transformer switching."),
+    )
+    args = parser.parse_args()
+
+    convert(
+        source_dir=args.source_dir,
+        output_dir=args.output_dir,
+        shared_components_dir=args.shared_components_dir,
+        pipeline_class_name=args.pipeline_class_name,
+        boundary_ratio=args.boundary_ratio,
+    )
+
+
+if __name__ == "__main__":
+    main()
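+
+# On success, output_dir is a self-contained FastVideo model repo:
+#   transformer/  transformer_2/  text_encoder/  tokenizer/  vae/
+#   model_index.json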