diff --git a/docs/inference/support_matrix.md b/docs/inference/support_matrix.md
index 931dd31618..c9d938e16e 100644
--- a/docs/inference/support_matrix.md
+++ b/docs/inference/support_matrix.md
@@ -48,6 +48,7 @@ pipeline initialization and sampling.
 | Model Name | HuggingFace Model ID | Resolutions | TeaCache | Sliding Tile Attn | Sage Attn | VSA | BSA |
 |------------|---------------------|-------------|----------|-------------------|-----------|-----|-----|
+| Ovis-Image 7B | `AIDC-AI/Ovis-Image-7B` | 1024×1024 | ⭕ | ⭕ | ⭕ | ⭕ | ⭕ |
 | FastWan2.1 T2V 1.3B | `FastVideo/FastWan2.1-T2V-1.3B-Diffusers` | 480P | ⭕ | ⭕ | ⭕ | ✅ | ⭕ |
 | FastWan2.2 TI2V 5B Full Attn* | `FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers` | 720P | ⭕ | ⭕ | ⭕ | ✅ | ⭕ |
 | Wan2.2 TI2V 5B | `Wan-AI/Wan2.2-TI2V-5B-Diffusers` | 720P | ⭕ | ⭕ | ✅ | ⭕ | ⭕ |
diff --git a/examples/inference/basic/basic_ovis_image.py b/examples/inference/basic/basic_ovis_image.py
new file mode 100644
index 0000000000..39d59a5e0e
--- /dev/null
+++ b/examples/inference/basic/basic_ovis_image.py
@@ -0,0 +1,104 @@
+"""
+Ovis-Image Text-to-Image Generation Example
+
+This example demonstrates how to use the Ovis-Image-7B model for high-quality
+text-to-image generation, especially for text rendering in images.
+
+Ovis-Image excels at:
+- Text rendering in posters, banners, logos
+- UI mockups with readable text
+- Infographics with correct spelling
+- Bilingual text rendering
+"""
+
+from fastvideo import VideoGenerator
+
+OUTPUT_PATH = "ovis_image_samples"
+
+
+def main():
+    # Load Ovis-Image from its HuggingFace model ID
+    # (a local checkpoint path also works here)
+    generator = VideoGenerator.from_pretrained(
+        "AIDC-AI/Ovis-Image-7B",
+        # FastVideo will automatically handle distributed setup
+        num_gpus=1,
+        use_fsdp_inference=False,
+        dit_cpu_offload=False,
+        vae_cpu_offload=False,
+        text_encoder_cpu_offload=False,  # Qwen3 encoder
+        pin_cpu_memory=True,
+    )
+
+    # Example 1: Text rendering in a poster
+    prompt1 = (
+        'A creative 3D artistic render where the text "OVIS-IMAGE" is written '
+        'in a bold, expressive handwritten brush style using thick, wet oil paint. '
+        'The paint is a mix of vibrant rainbow colors (red, blue, yellow) swirling '
+        'together like toothpaste or impasto art. You can see the ridges of the brush '
+        'bristles and the glossy, wet texture of the paint. The background is a clean '
+        "artist's canvas. Dynamic lighting creates soft shadows behind the floating "
+        'paint strokes. Colorful, expressive, tactile texture, 4k detail.'
+    )
+
+    print("Generating image 1: Text rendering poster...")
+    image1 = generator.generate_video(
+        prompt1,
+        output_path=OUTPUT_PATH,
+        save_video=True,
+        num_frames=1,  # Single image for T2I
+        height=1024,
+        width=1024,
+        num_inference_steps=50,
+        guidance_scale=5.0,
+    )
+
+    # Example 2: UI mockup with text
+    prompt2 = (
+        'A modern mobile app interface mockup showing a weather app. '
+        'At the top, display "Weather Today" in clean sans-serif font. '
+        'Below show the temperature "72°F" in large numbers. '
+        'Include labeled sections: "Humidity: 65%", "Wind: 12 mph", '
+        'and "Forecast: Sunny". Use a gradient blue background with '
+        'white text. Minimalist design, professional UI/UX, high resolution.'
+    )
+
+    print("Generating image 2: UI mockup...")
+    image2 = generator.generate_video(
+        prompt2,
+        output_path=OUTPUT_PATH,
+        save_video=True,
+        num_frames=1,  # Single image for T2I
+        height=1024,
+        width=768,  # Portrait orientation for mobile
+        num_inference_steps=50,
+        guidance_scale=5.0,
+    )
+
+    # Example 3: Logo with text
+    prompt3 = (
+        'A professional tech startup logo featuring the text "FAST AI" '
+        'in bold, modern geometric font. The letters are metallic silver '
+        'with a subtle blue glow effect. Below in smaller text: '
+        '"Innovation through Technology". Clean white background, '
+        'minimalist design, corporate branding style, vector-like quality.'
+    )
+
+    print("Generating image 3: Logo design...")
+    image3 = generator.generate_video(
+        prompt3,
+        output_path=OUTPUT_PATH,
+        save_video=True,
+        num_frames=1,  # Single image for T2I
+        height=512,
+        width=512,  # Square for logo
+        num_inference_steps=50,
+        guidance_scale=5.0,
+    )
+
+    print(f"\nAll images saved to {OUTPUT_PATH}/")
+    print("Ovis-Image generation complete!")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/fastvideo/configs/models/dits/__init__.py b/fastvideo/configs/models/dits/__init__.py
index 13416ffc19..1ffb2c92dc 100644
--- a/fastvideo/configs/models/dits/__init__.py
+++ b/fastvideo/configs/models/dits/__init__.py
@@ -8,10 +8,11 @@
 from fastvideo.configs.models.dits.stepvideo import StepVideoConfig
 from fastvideo.configs.models.dits.wanvideo import WanVideoConfig
 from fastvideo.configs.models.dits.hyworld import HYWorldConfig
+from fastvideo.configs.models.dits.ovisimage import OvisImageTransformer2DModelConfig
 
 __all__ = [
     "HunyuanVideoConfig", "HunyuanVideo15Config", "HunyuanGameCraftConfig",
-    "WanVideoConfig", "StepVideoConfig", "CosmosVideoConfig",
-    "Cosmos25VideoConfig", "LongCatVideoConfig", "LTX2VideoConfig",
-    "HYWorldConfig"
+    "WanVideoConfig", "StepVideoConfig", "CosmosVideoConfig",
+    "Cosmos25VideoConfig", "LongCatVideoConfig", "LTX2VideoConfig",
+    "HYWorldConfig", "OvisImageTransformer2DModelConfig"
 ]
diff --git a/fastvideo/configs/models/dits/ovisimage.py b/fastvideo/configs/models/dits/ovisimage.py
new file mode 100644
index 0000000000..8d547277d1
--- /dev/null
+++ b/fastvideo/configs/models/dits/ovisimage.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+"""Configuration for OvisImageTransformer2DModel"""
+
+from dataclasses import dataclass, field
+
+from fastvideo.configs.models.dits.base import DiTArchConfig, DiTConfig
+
+
+def _is_double_block(n: str, m) -> bool:
+    """Match transformer_blocks.{i} (double-stream blocks)."""
+    return "transformer_blocks" in n and "single" not in n and str.isdigit(
+        n.split(".")[-1])
+
+
+def _is_single_block(n: str, m) -> bool:
+    """Match single_transformer_blocks.{i} (single-stream blocks)."""
+    return "single_transformer_blocks" in n and str.isdigit(n.split(".")[-1])
+
+
+@dataclass
+class OvisImageTransformer2DModelArchConfig(DiTArchConfig):
+    """Architecture configuration for OvisImageTransformer2DModel."""
+
+    # Core architecture
+    hidden_size: int = 3072  # num_attention_heads * attention_head_dim = 24 * 128
+    num_attention_heads: int = 24
+    attention_head_dim: int = 128
+    num_layers: int = 6  # Number of joint (double) layers
+    num_single_layers: int = 27  # Number of single layers
+
+    # Input/output configuration
+    in_channels: int = 64
+    out_channels: int | None = None  # Can be None, defaults to in_channels
+    patch_size: int = 1
+
+    # Dimensions
+    joint_attention_dim: int = 2048  # Context dimension from text encoder
+    axes_dims_rope: 
list[int] = field(default_factory=lambda: [16, 56, 56]) + + # Legacy fields from base DiTArchConfig + num_channels_latents: int = 16 # VAE latent channels (in_channels=64 is packed=16*4) + + # FSDP: shard double and single transformer blocks + _fsdp_shard_conditions: list = field( + default_factory=lambda: [_is_double_block, _is_single_block]) + + # Compile: same as FSDP for now + _compile_conditions: list = field( + default_factory=lambda: [_is_double_block, _is_single_block]) + + # Weight name mapping: identity (native attrs match HF attrs) + param_names_mapping: dict = field(default_factory=dict) + reverse_param_names_mapping: dict = field(default_factory=dict) + lora_param_names_mapping: dict = field(default_factory=dict) + + +@dataclass +class OvisImageTransformer2DModelConfig(DiTConfig): + """Configuration for Ovis-Image DiT.""" + + arch_config: DiTArchConfig = field( + default_factory=OvisImageTransformer2DModelArchConfig) + prefix: str = "OvisImage" diff --git a/fastvideo/configs/models/encoders/__init__.py b/fastvideo/configs/models/encoders/__init__.py index 796b0bddfc..59b4ef22a5 100644 --- a/fastvideo/configs/models/encoders/__init__.py +++ b/fastvideo/configs/models/encoders/__init__.py @@ -7,6 +7,7 @@ from fastvideo.configs.models.encoders.llama import LlamaConfig from fastvideo.configs.models.encoders.t5 import T5Config, T5LargeConfig from fastvideo.configs.models.encoders.qwen2_5 import Qwen2_5_VLConfig +from fastvideo.configs.models.encoders.qwen3 import Qwen3ArchConfig, Qwen3Config from fastvideo.configs.models.encoders.siglip import SiglipVisionConfig from fastvideo.configs.models.encoders.reason1 import Reason1ArchConfig, Reason1Config from fastvideo.configs.models.encoders.gemma import LTX2GemmaConfig @@ -15,6 +16,6 @@ "EncoderConfig", "TextEncoderConfig", "ImageEncoderConfig", "BaseEncoderOutput", "CLIPTextConfig", "CLIPVisionConfig", "WAN2_1ControlCLIPVisionConfig", "LlamaConfig", "T5Config", "T5LargeConfig", - "Qwen2_5_VLConfig", "Reason1ArchConfig", "Reason1Config", "LTX2GemmaConfig", - "SiglipVisionConfig" + "Qwen2_5_VLConfig", "Qwen3ArchConfig", "Qwen3Config", "Reason1ArchConfig", + "Reason1Config", "LTX2GemmaConfig", "SiglipVisionConfig" ] diff --git a/fastvideo/configs/models/encoders/qwen3.py b/fastvideo/configs/models/encoders/qwen3.py new file mode 100644 index 0000000000..9fa66e661a --- /dev/null +++ b/fastvideo/configs/models/encoders/qwen3.py @@ -0,0 +1,99 @@ +# SPDX-License-Identifier: Apache-2.0 +from dataclasses import dataclass, field + +from fastvideo.configs.models.encoders.base import (TextEncoderArchConfig, + TextEncoderConfig) + + +def _is_transformer_layer(n: str, m) -> bool: + return "layers" in n and str.isdigit(n.split(".")[-1]) + + +def _is_embeddings(n: str, m) -> bool: + return n.endswith("embed_tokens") + + +def _is_final_norm(n: str, m) -> bool: + return n.endswith("norm") + + +@dataclass +class Qwen3ArchConfig(TextEncoderArchConfig): + """Architecture config for Qwen3 text encoder (used in Ovis-Image).""" + + # Model architecture - defaults from Ovis2.5-2B (Qwen3-2B) + vocab_size: int = 151936 + hidden_size: int = 2048 + intermediate_size: int = 6144 # Actual value from Ovis2.5-2B + num_hidden_layers: int = 28 # Actual value from Ovis2.5-2B + num_attention_heads: int = 16 + num_key_value_heads: int = 8 # Actual value from Ovis2.5-2B + hidden_act: str = "silu" + max_position_embeddings: int = 40960 # Actual value from Ovis2.5-2B + initializer_range: float = 0.02 + rms_norm_eps: float = 1e-06 + use_cache: bool = True + 
tie_word_embeddings: bool = True + rope_theta: float = 1000000.0 + rope_scaling: dict | None = None + use_sliding_window: bool = False + sliding_window: int | None = None # Can be None + max_window_layers: int = 28 # Actual value from Ovis2.5-2B + attention_dropout: float = 0.0 + attention_bias: bool = False + head_dim: int = 128 + + # HuggingFace transformers fields + bos_token_id: int = 151643 + eos_token_id: int = 151645 + dtype: str = "float32" + _attn_implementation_autoset: bool = True + layer_types: list[str] = field( + default_factory=lambda: ["full_attention"] * 28) + + # FastVideo-specific settings + hidden_state_skip_layer: int = 0 + text_len: int = 256 + + # Ovis-Image uses system prompt tokens (28 tokens) prepended to user tokens + user_prompt_begin_id: int = 28 + + # Qwen3-specific stacked params + stacked_params_mapping: list[tuple[str, str, str]] = field( + default_factory=lambda: [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), # type: ignore + (".gate_up_proj", ".up_proj", 1), # type: ignore + ]) + + _fsdp_shard_conditions: list = field( + default_factory=lambda: + [_is_transformer_layer, _is_embeddings, _is_final_norm]) + + def __post_init__(self): + super().__post_init__() + # Override tokenizer_kwargs for apply_chat_template + # Ovis-Image uses chat template with system prompt (28 tokens prepended) + # Total max_length = text_len + user_prompt_begin_id + self.tokenizer_kwargs = { + "add_generation_prompt": True, + "tokenize": True, + "return_dict": True, + "padding": "max_length", + "max_length": self.text_len + self.user_prompt_begin_id, + "truncation": True, + "return_tensors": "pt", + "enable_thinking": False, + } + + +@dataclass +class Qwen3Config(TextEncoderConfig): + """Configuration for Qwen3 text encoder.""" + + arch_config: TextEncoderArchConfig = field(default_factory=Qwen3ArchConfig) + prefix: str = "qwen3" + is_chat_model: bool = True diff --git a/fastvideo/configs/models/vaes/base.py b/fastvideo/configs/models/vaes/base.py index 7bff6d8239..8986470eaa 100644 --- a/fastvideo/configs/models/vaes/base.py +++ b/fastvideo/configs/models/vaes/base.py @@ -17,6 +17,26 @@ class VAEArchConfig(ArchConfig): temporal_compression_ratio: int = 4 spatial_compression_ratio: int = 8 + # Additional fields from diffusers AutoencoderKL + act_fn: str = "silu" + block_out_channels: list[int] = field( + default_factory=lambda: [128, 256, 512, 512]) + down_block_types: list[str] = field(default_factory=list) + up_block_types: list[str] = field(default_factory=list) + force_upcast: bool = False + in_channels: int = 3 + latent_channels: int = 16 + latents_mean: list[float] | None = None + latents_std: list[float] | None = None + layers_per_block: int = 2 + mid_block_add_attention: bool = True + norm_num_groups: int = 32 + out_channels: int = 3 + sample_size: int = 1024 + shift_factor: float | None = None + use_post_quant_conv: bool = False + use_quant_conv: bool = False + @dataclass class VAEConfig(ModelConfig): diff --git a/fastvideo/configs/ovis_image_7b_t2i_pipeline.json b/fastvideo/configs/ovis_image_7b_t2i_pipeline.json new file mode 100644 index 0000000000..0661779d1a --- /dev/null +++ b/fastvideo/configs/ovis_image_7b_t2i_pipeline.json @@ -0,0 +1,37 @@ +{ + "embedded_cfg_scale": 5.0, + "flow_shift": 3.0, + "dit_cpu_offload": false, + "disable_autocast": false, + "precision": "bf16", + "vae_precision": "fp32", + "vae_tiling": true, + "vae_sp": false, 
+ "vae_config": { + "load_encoder": false, + "load_decoder": true, + "tile_sample_min_height": 256, + "tile_sample_min_width": 256, + "tile_sample_stride_height": 192, + "tile_sample_stride_width": 192, + "use_tiling": true, + "use_temporal_tiling": false, + "use_parallel_tiling": false, + "use_feature_cache": true + }, + "dit_config": { + "prefix": "OvisImage", + "quant_config": null + }, + "text_encoder_precisions": [ + "bf16" + ], + "text_encoder_configs": [ + { + "prefix": "qwen3", + "quant_config": null, + "lora_config": null + } + ], + "enable_torch_compile": false +} diff --git a/fastvideo/configs/pipelines/ovis_image.py b/fastvideo/configs/pipelines/ovis_image.py new file mode 100644 index 0000000000..65e0ea714f --- /dev/null +++ b/fastvideo/configs/pipelines/ovis_image.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Pipeline configuration for Ovis-Image text-to-image model.""" + +from collections.abc import Callable +from dataclasses import dataclass, field +from typing import Any + +import torch + +from fastvideo.configs.models import DiTConfig, EncoderConfig +from fastvideo.configs.models.dits import OvisImageTransformer2DModelConfig +from fastvideo.configs.models.encoders import BaseEncoderOutput, Qwen3Config +from fastvideo.configs.pipelines.base import PipelineConfig + +# System prompt from the Diffusers OvisImagePipeline +OVIS_SYSTEM_PROMPT = ( + "Describe the image by detailing the color, quantity, text, shape, size, " + "texture, spatial relationships of the objects and background: ") +# Number of tokens the system prompt + chat template special tokens occupy +USER_PROMPT_BEGIN_ID = 28 + + +def qwen3_preprocess_text(prompt: str) -> list[dict[str, Any]]: + """Format prompt as a chat message with system prompt for Qwen3. + + The Ovis-Image pipeline prepends a system prompt to guide text encoding, + formatted as a single user message for the chat template. + """ + return [{"role": "user", "content": OVIS_SYSTEM_PROMPT + prompt}] + + +def qwen3_postprocess_text( + outputs: BaseEncoderOutput, + mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Post-process Qwen3 encoder output for Ovis-Image. + + Applies attention mask to zero out padding tokens, then slices off + the system prompt tokens (first 28 tokens) from both embeddings and mask. + """ + prompt_embeds = outputs.last_hidden_state + # Zero out padding tokens + prompt_embeds = prompt_embeds * mask[..., None] + # Slice off system prompt tokens + prompt_embeds = prompt_embeds[:, USER_PROMPT_BEGIN_ID:] + mask = mask[:, USER_PROMPT_BEGIN_ID:] + return prompt_embeds, mask + + +@dataclass +class OvisImageT2IConfig(PipelineConfig): + """ + Configuration for Ovis-Image-7B text-to-image pipeline. + + Ovis-Image is optimized for high-quality text rendering in generated images. + This config uses Qwen3 (Ovis2.5-2B) as the text encoder. + """ + + # Denoising stage + embedded_cfg_scale: float = 5.0 + flow_shift: float = 3.0 + + # DiT configuration + dit_config: DiTConfig = field( + default_factory=OvisImageTransformer2DModelConfig) + + # Text encoding stage + text_encoder_configs: tuple[EncoderConfig, ...] = field( + default_factory=lambda: (Qwen3Config(), )) + preprocess_text_funcs: tuple[Callable[[str], list[dict[str, Any]]], + ...] = field(default_factory=lambda: + (qwen3_preprocess_text, )) + postprocess_text_funcs: tuple[Callable[[Any, Any], tuple[Any, Any]], + ...] 
= field(default_factory=lambda: + (qwen3_postprocess_text, )) + + # Precision for each component + dit_precision: str = "bf16" + vae_precision: str = "fp32" + text_encoder_precisions: tuple[str, ...] = field( + default_factory=lambda: ("bf16", )) + + def __post_init__(self): + """Configure VAE for decoder-only mode.""" + # Since VAEConfig may be set via kwargs, check and configure + if hasattr(self, 'vae_config') and self.vae_config is not None: + self.vae_config.load_encoder = False + self.vae_config.load_decoder = True diff --git a/fastvideo/models/dits/ovisimage.py b/fastvideo/models/dits/ovisimage.py new file mode 100644 index 0000000000..c29cdaaaf6 --- /dev/null +++ b/fastvideo/models/dits/ovisimage.py @@ -0,0 +1,753 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Ovis-Image Transformer2D Model — Approach B (FastVideo-native implementation) + +Architecture: FLUX-like MM-DiT with: + - 6 double-stream (joint) transformer blocks + - 27 single-stream transformer blocks + - Qwen3 text encoder (2048-dim) projected to shared 3072-dim hidden space + - FLUX-style 3D RoPE with axes_dims_rope=[16, 56, 56] + - FastVideo DistributedAttention for SP support + - ReplicatedLinear layers (FSDP-compatible, TP-ready) + - CachableDiT base for TeaCache optimization + +Weight attribute names match Diffusers OvisImageTransformer2DModel exactly, +so param_names_mapping = {} and weights load without any remapping. +""" + +import math +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from fastvideo.attention import DistributedAttention +from fastvideo.configs.models.dits import OvisImageTransformer2DModelConfig +from fastvideo.configs.models.dits.base import DiTConfig +from fastvideo.distributed.communication_op import ( + sequence_model_parallel_all_gather, sequence_model_parallel_shard) +from fastvideo.forward_context import get_forward_context +from fastvideo.layers.layernorm import RMSNorm +from fastvideo.layers.linear import ReplicatedLinear +from fastvideo.models.dits.base import CachableDiT +from fastvideo.platforms import AttentionBackendEnum +from fastvideo.logger import init_logger + +logger = init_logger(__name__) + +# --------------------------------------------------------------------------- +# Helpers: latent packing / unpacking and position IDs +# --------------------------------------------------------------------------- + + +def _pack_latents(latents: torch.Tensor) -> torch.Tensor: + """Pack [B, C, H, W] -> [B, (H/2)*(W/2), C*4] for Ovis-Image transformer.""" + B, C, H, W = latents.shape + latents = latents.view(B, C, H // 2, 2, W // 2, 2) + latents = latents.permute(0, 2, 4, 1, 3, 5) + return latents.reshape(B, (H // 2) * (W // 2), C * 4) + + +def _unpack_latents(latents: torch.Tensor, H: int, W: int) -> torch.Tensor: + """Unpack [B, (H/2)*(W/2), C*4] -> [B, C, H, W].""" + B, _, channels = latents.shape + C = channels // 4 + latents = latents.view(B, H // 2, W // 2, C, 2, 2) + latents = latents.permute(0, 3, 1, 4, 2, 5) + return latents.reshape(B, C, H, W) + + +def _prepare_img_ids(H_half: int, W_half: int, + device: torch.device) -> torch.Tensor: + """Image position IDs for RoPE: [H_half*W_half, 3] with (0, row, col).""" + ids = torch.zeros(H_half, W_half, 3, device=device) + ids[..., 1] = torch.arange(H_half, device=device)[:, None] + ids[..., 2] = torch.arange(W_half, device=device)[None, :] + return ids.reshape(H_half * W_half, 3) + + +def _prepare_txt_ids(seq_len: int, device: torch.device) -> torch.Tensor: + """Text position IDs for RoPE: 
[seq_len, 3] with (0, i, i).""" + ids = torch.zeros(seq_len, 3, device=device) + ids[:, 1] = torch.arange(seq_len, device=device) + ids[:, 2] = torch.arange(seq_len, device=device) + return ids + + +# --------------------------------------------------------------------------- +# FLUX-style RoPE +# --------------------------------------------------------------------------- + + +class OvisImageRoPE(nn.Module): + """ + FLUX-style 3D RoPE for Ovis-Image. + + Splits head_dim across three axes according to axes_dims_rope, computing + separate frequency tables per axis and concatenating them. + + Position IDs: [..., 3] where components index (axis0, axis1, axis2). + For 2D images: axis0=0, axis1=row, axis2=col. + """ + + def __init__(self, head_dim: int, axes_dims: list[int], + theta: float = 10000.0): + super().__init__() + assert sum(axes_dims) == head_dim, ( + f"sum(axes_dims)={sum(axes_dims)} != head_dim={head_dim}") + self.head_dim = head_dim + self.axes_dims = axes_dims + self.theta = theta + + def _freqs_for_axis(self, dim: int, + device: torch.device) -> torch.Tensor: + """Inverse frequency vector for one axis dimension.""" + half = dim // 2 + return 1.0 / (self.theta ** ( + torch.arange(0, half, device=device, dtype=torch.float32) / half)) + + def forward(self, ids: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Args: + ids: position IDs [..., 3] + Returns: + (cos, sin) each [..., head_dim] + """ + cos_parts, sin_parts = [], [] + for axis_idx, dim in enumerate(self.axes_dims): + pos = ids[..., axis_idx].float() + inv_freq = self._freqs_for_axis(dim, ids.device) + freqs = torch.outer(pos.reshape(-1), + inv_freq).reshape(*pos.shape, -1) + # Interleaved pairs [θ0,θ0,θ1,θ1,...] — matches Diffusers repeat_interleave_real=True + emb = freqs.repeat_interleave(2, dim=-1) + cos_parts.append(emb.cos()) + sin_parts.append(emb.sin()) + + return torch.cat(cos_parts, dim=-1), torch.cat(sin_parts, dim=-1) + + +def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings — pair-wise complex rotation matching Diffusers + apply_rotary_emb with use_real_unbind_dim=-1. + + q, k: [B, seq, n_heads, head_dim] + cos, sin: [seq, head_dim] (interleaved pairs encoding) + """ + cos = cos.unsqueeze(0).unsqueeze(2) # [1, seq, 1, head_dim] + sin = sin.unsqueeze(0).unsqueeze(2) + + # Unbind adjacent pairs: x -> (x_real, x_imag) each [..., head_dim//2] + q_r, q_i = q.reshape(*q.shape[:-1], -1, 2).unbind(-1) + k_r, k_i = k.reshape(*k.shape[:-1], -1, 2).unbind(-1) + # Rotate: [-imag, real] — matches torch.stack([-x_imag, x_real]).flatten + q_rot = torch.stack([-q_i, q_r], dim=-1).flatten(-2) + k_rot = torch.stack([-k_i, k_r], dim=-1).flatten(-2) + + return (q.float() * cos + q_rot.float() * sin).to(q.dtype), ( + k.float() * cos + k_rot.float() * sin).to(k.dtype) + + +# --------------------------------------------------------------------------- +# Adaptive layer norms +# --------------------------------------------------------------------------- + + +class OvisAdaLayerNormZero(nn.Module): + """ + Adaptive LayerNorm with zero-initialized modulation for double stream blocks. + Produces 6 values: (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp). 
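+
+    Shape sketch (illustrative, matching the forward below): ``c`` is
+    [B, hidden] and ``x`` is [B, seq, hidden]; each of the six chunks is
+    [B, hidden] and broadcasts over the sequence dimension::
+
+        emb = linear(silu(c))          # [B, 6 * hidden]
+        chunks = emb.chunk(6, dim=-1)  # six [B, hidden] tensors
+        x_norm = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]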
+ """ + + def __init__(self, hidden_size: int): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, 6 * hidden_size, bias=True) + + def forward( + self, x: torch.Tensor, c: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: + emb = self.linear(F.silu(c)) + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + emb.chunk(6, dim=-1)) + # c is [B, hidden]; x is [B, seq, hidden] — unsqueeze for broadcast + x_norm = self.norm(x) * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1) + return x_norm, gate_msa.unsqueeze(1), shift_mlp.unsqueeze(1), scale_mlp.unsqueeze(1), gate_mlp.unsqueeze(1) + + +class OvisAdaLayerNormZeroSingle(nn.Module): + """ + Adaptive LayerNorm with zero-initialized modulation for single stream blocks. + Produces 3 values: (shift, scale, gate). + """ + + def __init__(self, hidden_size: int): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, 3 * hidden_size, bias=True) + + def forward( + self, x: torch.Tensor, c: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + emb = self.linear(F.silu(c)) + shift, scale, gate = emb.chunk(3, dim=-1) + # c is [B, hidden]; x is [B, seq, hidden] — unsqueeze for broadcast + x_norm = self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + return x_norm, gate.unsqueeze(1) + + +class OvisAdaLayerNormContinuous(nn.Module): + """Final adaptive LayerNorm before output projection.""" + + def __init__(self, hidden_size: int): + super().__init__() + self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, 2 * hidden_size, bias=True) + + def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor: + emb = self.linear(F.silu(c)) + scale, shift = emb.chunk(2, dim=-1) + # c is [B, hidden]; x is [B, seq, hidden] — unsqueeze for broadcast + return self.norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +# --------------------------------------------------------------------------- +# GEGLU feed-forward +# --------------------------------------------------------------------------- + + +class OvisGEGLUFeedForward(nn.Module): + """ + GEGLU feed-forward used in double stream blocks. + Attribute names match Diffusers ff.net structure for weight compatibility. + """ + + def __init__(self, hidden_size: int, ff_dim: int): + super().__init__() + # Diffusers stores as ff.net[0].proj (GEGLU: gate+up fused) and ff.net[2] + self.net = nn.ModuleList([ + _GEGLUGateUp(hidden_size, ff_dim), # index 0 + nn.Identity(), # index 1 (dropout placeholder) + nn.Linear(ff_dim, hidden_size, bias=True), # index 2 + ]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.net[0](x) # GEGLU + out = self.net[2](x) # down projection + return out + + +class _GEGLUGateUp(nn.Module): + """SwiGLU gate+up projection matching Diffusers FeedForward(activation_fn='swiglu'). + + Diffusers SwiGLU: proj -> [hidden, gate], return hidden * silu(gate). 
+ """ + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.proj = nn.Linear(in_features, out_features * 2, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden, gate = self.proj(x).chunk(2, dim=-1) + return hidden * F.silu(gate) + + +# --------------------------------------------------------------------------- +# Attention sub-modules (attr names match Diffusers for weight loading) +# --------------------------------------------------------------------------- + + +class _OvisDoubleAttn(nn.Module): + """ + Joint attention for double stream blocks. + Attribute names mirror Diffusers: to_q, to_k, to_v, to_out, + add_q_proj, add_k_proj, add_v_proj, to_add_out, norm_q/k, norm_added_q/k. + """ + + def __init__(self, hidden_size: int, num_heads: int, head_dim: int, + supported_attention_backends, prefix: str): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + # Image QKV + output + self.to_q = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_out = nn.ModuleList( + [nn.Linear(hidden_size, hidden_size, bias=True)]) + + # Text QKV + output + self.add_q_proj = nn.Linear(hidden_size, hidden_size, bias=True) + self.add_k_proj = nn.Linear(hidden_size, hidden_size, bias=True) + self.add_v_proj = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_add_out = nn.Linear(hidden_size, hidden_size, bias=True) + + # QK-Norm + self.norm_q = RMSNorm(head_dim, eps=1e-6) + self.norm_k = RMSNorm(head_dim, eps=1e-6) + self.norm_added_q = RMSNorm(head_dim, eps=1e-6) + self.norm_added_k = RMSNorm(head_dim, eps=1e-6) + + # Distributed attention (SP-aware) + self.attn_op = DistributedAttention( + num_heads=num_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn_op") + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + img_cos: torch.Tensor, + img_sin: torch.Tensor, + txt_cos: torch.Tensor, + txt_sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + B, img_seq = img.shape[:2] + txt_seq = txt.shape[1] + + # Image QKV + img_q = self.to_q(img).view(B, img_seq, self.num_heads, self.head_dim) + img_k = self.to_k(img).view(B, img_seq, self.num_heads, self.head_dim) + img_v = self.to_v(img).view(B, img_seq, self.num_heads, self.head_dim) + img_q = self.norm_q(img_q).to(img_v.dtype) + img_k = self.norm_k(img_k).to(img_v.dtype) + img_q, img_k = _apply_rope(img_q, img_k, img_cos, img_sin) + + # Text QKV + txt_q = self.add_q_proj(txt).view(B, txt_seq, self.num_heads, + self.head_dim) + txt_k = self.add_k_proj(txt).view(B, txt_seq, self.num_heads, + self.head_dim) + txt_v = self.add_v_proj(txt).view(B, txt_seq, self.num_heads, + self.head_dim) + txt_q = self.norm_added_q(txt_q).to(txt_v.dtype) + txt_k = self.norm_added_k(txt_k).to(txt_v.dtype) + txt_q, txt_k = _apply_rope(txt_q, txt_k, txt_cos, txt_sin) + + # Joint attention via DistributedAttention + img_attn, txt_attn = self.attn_op(img_q, img_k, img_v, txt_q, txt_k, + txt_v) + + img_out = self.to_out[0](img_attn.reshape(B, img_seq, -1)) + txt_out = self.to_add_out(txt_attn.reshape(B, txt_seq, -1)) + return img_out, txt_out + + +class _OvisSingleAttn(nn.Module): + """ + Single-stream attention (merged image+text). + Attribute names match Diffusers for weight loading. 
+ """ + + def __init__(self, hidden_size: int, num_heads: int, head_dim: int, + supported_attention_backends, prefix: str): + super().__init__() + self.num_heads = num_heads + self.head_dim = head_dim + + self.to_q = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_k = nn.Linear(hidden_size, hidden_size, bias=True) + self.to_v = nn.Linear(hidden_size, hidden_size, bias=True) + self.norm_q = RMSNorm(head_dim, eps=1e-6) + self.norm_k = RMSNorm(head_dim, eps=1e-6) + + self.attn_op = DistributedAttention( + num_heads=num_heads, + head_size=head_dim, + causal=False, + supported_attention_backends=supported_attention_backends, + prefix=f"{prefix}.attn_op") + + def forward(self, x: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> torch.Tensor: + B, seq = x.shape[:2] + q = self.to_q(x).view(B, seq, self.num_heads, self.head_dim) + k = self.to_k(x).view(B, seq, self.num_heads, self.head_dim) + v = self.to_v(x).view(B, seq, self.num_heads, self.head_dim) + q = self.norm_q(q).to(v.dtype) + k = self.norm_k(k).to(v.dtype) + q, k = _apply_rope(q, k, cos, sin) + # Single stream: pass img+txt jointly, no split + attn_out, _ = self.attn_op(q, k, v, None, None, None) + return attn_out.reshape(B, seq, -1) + + +# --------------------------------------------------------------------------- +# Double stream block +# --------------------------------------------------------------------------- + + +class OvisImageDoubleStreamBlock(nn.Module): + """ + FLUX-style joint (double-stream) transformer block. + + Image and text each have their own adaptive LayerNorm and FFN, but share + a joint cross-attention. Attribute names match Diffusers for weight compat. + """ + + def __init__(self, hidden_size: int, num_heads: int, head_dim: int, + ff_dim: int, supported_attention_backends, prefix: str): + super().__init__() + + # Image stream + self.norm1 = OvisAdaLayerNormZero(hidden_size) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.ff = OvisGEGLUFeedForward(hidden_size, ff_dim) + + # Text stream + self.norm1_context = OvisAdaLayerNormZero(hidden_size) + self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, + eps=1e-6) + self.ff_context = OvisGEGLUFeedForward(hidden_size, ff_dim) + + # Joint attention + self.attn = _OvisDoubleAttn(hidden_size, num_heads, head_dim, + supported_attention_backends, + prefix=f"{prefix}.attn") + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + vec: torch.Tensor, + img_cos: torch.Tensor, + img_sin: torch.Tensor, + txt_cos: torch.Tensor, + txt_sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Adaptive norms + img_n, img_gate_msa, img_shift_mlp, img_scale_mlp, img_gate_mlp = ( + self.norm1(img, vec)) + txt_n, txt_gate_msa, txt_shift_mlp, txt_scale_mlp, txt_gate_mlp = ( + self.norm1_context(txt, vec)) + + # Joint attention + img_attn, txt_attn = self.attn(img_n, txt_n, img_cos, img_sin, txt_cos, + txt_sin) + + # Image: residual + norm + MLP + img = img + img_gate_msa * img_attn + img_ff_in = self.norm2(img) * (1 + img_scale_mlp) + img_shift_mlp + img = img + img_gate_mlp * self.ff(img_ff_in) + + # Text: residual + norm + MLP + txt = txt + txt_gate_msa * txt_attn + txt_ff_in = self.norm2_context(txt) * (1 + txt_scale_mlp) + txt_shift_mlp + txt = txt + txt_gate_mlp * self.ff_context(txt_ff_in) + + return img, txt + + +# --------------------------------------------------------------------------- +# Single stream block +# --------------------------------------------------------------------------- + + 
+class OvisImageSingleStreamBlock(nn.Module): + """ + FLUX-style single-stream transformer block. + + Receives img and txt separately, concatenates txt-first internally (matching + Diffusers), processes jointly, then splits and returns (txt, img). + Attribute names match Diffusers for weight compat. + """ + + def __init__(self, hidden_size: int, num_heads: int, head_dim: int, + mlp_ratio: float, supported_attention_backends, prefix: str): + super().__init__() + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + + self.norm = OvisAdaLayerNormZeroSingle(hidden_size) + self.attn = _OvisSingleAttn(hidden_size, num_heads, head_dim, + supported_attention_backends, + prefix=f"{prefix}.attn") + # proj_mlp outputs 2 × mlp_hidden_dim for SiLU gating (matching Diffusers) + self.proj_mlp = nn.Linear(hidden_size, self.mlp_hidden_dim * 2, bias=True) + self.proj_out = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, + bias=True) + + def forward( + self, + img: torch.Tensor, + txt: torch.Tensor, + temb: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + txt_seq = txt.shape[1] + # Concat txt first then img (matching Diffusers convention) + x = torch.cat([txt, img], dim=1) + residual = x + + x_norm, gate = self.norm(x, temb) + # SiLU-gated MLP (not GeLU): proj_mlp → [hidden, gate], silu(gate) * hidden + mlp_out, mlp_gate = self.proj_mlp(x_norm).chunk(2, dim=-1) + mlp_out = F.silu(mlp_gate) * mlp_out + attn_out = self.attn(x_norm, cos, sin) + combined = torch.cat([attn_out, mlp_out], dim=-1) + x = residual + gate * self.proj_out(combined) + + txt_out = x[:, :txt_seq] + img_out = x[:, txt_seq:] + return txt_out, img_out + + +# --------------------------------------------------------------------------- +# Timestep + pooled text conditioning +# --------------------------------------------------------------------------- + + +def _timestep_embedding(t: torch.Tensor, dim: int, + max_period: int = 10000) -> torch.Tensor: + """Sinusoidal timestep embedding.""" + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * + torch.arange(0, half, dtype=torch.float32, device=t.device) / half) + args = t[:, None].float() * freqs[None] + return torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + +class OvisTimestepEmbedder(nn.Module): + """ + Timestep MLP matching Diffusers TimestepEmbedding weight names. + + Checkpoint keys: timestep_embedder.linear_1, timestep_embedder.linear_2 + """ + + def __init__(self, in_channels: int, hidden_size: int): + super().__init__() + self.linear_1 = nn.Linear(in_channels, hidden_size, bias=True) + self.linear_2 = nn.Linear(hidden_size, hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_2(F.silu(self.linear_1(x))) + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + +_CFG = OvisImageTransformer2DModelConfig() + + +class OvisImageTransformer2DModel(CachableDiT): + """ + Native FastVideo implementation of the Ovis-Image diffusion transformer. 
+ + Architecture: FLUX-like MM-DiT + - 6 double-stream (joint) blocks (transformer_blocks) + - 27 single-stream blocks (single_transformer_blocks) + - FLUX-style 3D RoPE (axes_dims_rope=[16, 56, 56]) + - FastVideo DistributedAttention (SP-compatible) + - CachableDiT base (TeaCache-ready) + + Weight names match Diffusers OvisImageTransformer2DModel exactly + => param_names_mapping = {} (no remapping, weights load directly). + """ + + # ---- Required CachableDiT class attributes ---- + _fsdp_shard_conditions = _CFG._fsdp_shard_conditions + _compile_conditions: list = [] + _supported_attention_backends: tuple[ + AttentionBackendEnum, ...] = _CFG._supported_attention_backends + param_names_mapping: dict = {} + reverse_param_names_mapping: dict = {} + lora_param_names_mapping: dict = {} + + def __init__(self, config: DiTConfig, hf_config: dict[str, Any], + **kwargs) -> None: + super().__init__(config=config, hf_config=hf_config) + + arch = config.arch_config + hidden_size: int = arch.hidden_size + num_heads: int = arch.num_attention_heads + head_dim: int = arch.attention_head_dim + num_layers: int = arch.num_layers + num_single_layers: int = arch.num_single_layers + in_channels: int = arch.in_channels + out_channels: int = (arch.out_channels + if arch.out_channels is not None else in_channels) + joint_attention_dim: int = arch.joint_attention_dim + + # FastVideo required instance attributes + self.hidden_size = hidden_size + self.num_attention_heads = num_heads + self.num_channels_latents = arch.num_channels_latents # 16 (VAE latent ch) + self.out_channels = out_channels + self.in_channels = in_channels + + ff_dim = hidden_size * 4 # standard MLP ratio + + # Input projections (weight names match Diffusers) + self.x_embedder = nn.Linear(in_channels, hidden_size, bias=True) + # Norm applied to text encoder output before projection (matches Diffusers) + self.context_embedder_norm = RMSNorm(joint_attention_dim, eps=1e-6) + self.context_embedder = nn.Linear(joint_attention_dim, hidden_size, + bias=True) + + # Timestep conditioning (purely from timestep, no pooled text) + # Matches Diffusers: timestep_embedder.linear_1 / timestep_embedder.linear_2 + self._freq_dim = 256 + self.timestep_embedder = OvisTimestepEmbedder(self._freq_dim, hidden_size) + + # Transformer blocks + self.transformer_blocks = nn.ModuleList([ + OvisImageDoubleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + head_dim=head_dim, + ff_dim=ff_dim, + supported_attention_backends=self._supported_attention_backends, + prefix=f"transformer_blocks.{i}", + ) for i in range(num_layers) + ]) + + self.single_transformer_blocks = nn.ModuleList([ + OvisImageSingleStreamBlock( + hidden_size=hidden_size, + num_heads=num_heads, + head_dim=head_dim, + mlp_ratio=4.0, + supported_attention_backends=self._supported_attention_backends, + prefix=f"single_transformer_blocks.{i}", + ) for i in range(num_single_layers) + ]) + + # Output (weight names match Diffusers) + self.norm_out = OvisAdaLayerNormContinuous(hidden_size) + self.proj_out = nn.Linear(hidden_size, out_channels, bias=True) + + # RoPE + self.rope = OvisImageRoPE( + head_dim=head_dim, + axes_dims=list(arch.axes_dims_rope), + ) + + self.__post_init__() + + # ------------------------------------------------------------------ + # Forward + # ------------------------------------------------------------------ + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | list[torch.Tensor], + timestep: torch.LongTensor, + encoder_hidden_states_image: 
torch.Tensor | list[torch.Tensor]
+        | None = None,
+        guidance=None,
+        **kwargs,
+    ) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [B, C, T, H, W] or [B, C, H, W] latents
+            encoder_hidden_states: [B, txt_seq, joint_attention_dim] or list
+            timestep: [B] diffusion timestep
+        Returns:
+            Denoised latents in same shape as hidden_states
+        """
+        if isinstance(encoder_hidden_states, list):
+            encoder_hidden_states = encoder_hidden_states[0]
+
+        had_temporal = hidden_states.ndim == 5
+        if had_temporal:
+            hidden_states = hidden_states.squeeze(2)
+
+        B, C, H, W = hidden_states.shape
+
+        # Pack latents: [B, C, H, W] -> [B, img_seq, C*4]
+        img_latents = _pack_latents(hidden_states)
+
+        # Project to hidden_size
+        img = self.x_embedder(img_latents)  # [B, img_seq, hidden_size]
+        # Apply RMSNorm before projecting text (matches Diffusers context_embedder_norm)
+        enc_norm = self.context_embedder_norm(encoder_hidden_states)
+        txt = self.context_embedder(enc_norm)  # [B, txt_seq, hidden_size]
+
+        txt_seq = txt.shape[1]
+        img_seq = img.shape[1]
+
+        # Timestep-only conditioning (Diffusers: timestep * 1000 then sinusoidal)
+        # Input timestep is in [0, 1000]; sinusoidal embedding is computed at that scale
+        t_emb = _timestep_embedding(timestep, self._freq_dim).to(img.dtype)
+        temb = self.timestep_embedder(t_emb)  # [B, hidden_size]
+
+        # RoPE position IDs — joint sequence: txt first, img second (matches Diffusers)
+        img_ids = kwargs.get("img_ids")
+        txt_ids = kwargs.get("txt_ids")
+        if img_ids is None:
+            img_ids = _prepare_img_ids(H // 2, W // 2, hidden_states.device)
+        if txt_ids is None:
+            txt_ids = _prepare_txt_ids(txt_seq, hidden_states.device)
+
+        # Joint RoPE (txt first, then img)
+        joint_ids = torch.cat([txt_ids, img_ids], dim=0)
+        joint_cos, joint_sin = self.rope(joint_ids)
+        joint_cos = joint_cos.to(img.dtype)
+        joint_sin = joint_sin.to(img.dtype)
+        txt_cos = joint_cos[:txt_seq]
+        txt_sin = joint_sin[:txt_seq]
+        img_cos = joint_cos[txt_seq:]
+        img_sin = joint_sin[txt_seq:]
+
+        # TeaCache early exit check
+        forward_context = get_forward_context()
+        forward_batch = getattr(forward_context, "forward_batch", None)
+        enable_teacache = (forward_batch is not None
+                           and getattr(forward_batch, "enable_teacache", False))
+        if enable_teacache:
+            original_img = img.clone()
+
+        # Sequence Parallelism: shard image sequence across SP ranks
+        img, _ = sequence_model_parallel_shard(img, dim=1)
+
+        # Double-stream blocks: temb as conditioning
+        for block in self.transformer_blocks:
+            img, txt = block(img, txt, temb, img_cos, img_sin, txt_cos, txt_sin)
+
+        # Single-stream blocks: blocks handle txt/img concat internally (txt first)
+        for block in self.single_transformer_blocks:
+            txt, img = block(img, txt, temb, joint_cos, joint_sin)
+
+        # Gather SP shards
+        img = sequence_model_parallel_all_gather(img, dim=1)
+
+        if enable_teacache:
+            self.maybe_cache_states(img, original_img)
+
+        # Output (img stream only)
+        img = self.norm_out(img, temb)
+        img = self.proj_out(img)
+
+        output = _unpack_latents(img, H, W)
+        if had_temporal:
+            output = output.unsqueeze(2)
+
+        return output
+
+    # ------------------------------------------------------------------
+    # TeaCache interface (CachableDiT)
+    # ------------------------------------------------------------------
+
+    def maybe_cache_states(self, hidden_states: torch.Tensor,
+                           original_hidden_states: torch.Tensor) -> None:
+        """Cache residual between current and original hidden states."""
+        self.previous_residual = hidden_states - original_hidden_states
+
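+    # Illustrative usage sketch (an assumption, mirroring FastVideo's TeaCache
+    # pattern in other DiTs rather than code in this diff): a caller honoring
+    # the cache would do roughly
+    #
+    #     if self.should_skip_forward_for_cached_states():
+    #         img = img + self.previous_residual   # reuse cached residual
+    #     else:
+    #         original_img = img.clone()
+    #         ...                                  # run transformer blocks
+    #         self.maybe_cache_states(img, original_img)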
+    def should_skip_forward_for_cached_states(self, **kwargs) -> bool:
+        """TeaCache skip decision — not yet calibrated for Ovis-Image."""
+        # Always compute for now; calibrate coefficients before enabling skips
+        return False
diff --git a/fastvideo/models/encoders/qwen3.py b/fastvideo/models/encoders/qwen3.py
new file mode 100644
index 0000000000..47033247ac
--- /dev/null
+++ b/fastvideo/models/encoders/qwen3.py
@@ -0,0 +1,456 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Qwen3 Text Encoder — Approach B (FastVideo-native implementation)
+
+Architecture (Qwen3-2B as used in Ovis-Image-7B):
+  - Standard transformer decoder (no MLLM vision, pure text)
+  - GQA attention (16 Q heads, 8 KV heads)
+  - QK-Norm: RMSNorm applied to Q and K before attention (Qwen3 specific)
+  - SwiGLU MLP via MergedColumnParallelLinear + RowParallelLinear
+  - RoPE position embeddings (standard, not multi-modal)
+  - Tensor Parallelism via QKVParallelLinear / RowParallelLinear
+  - Quantization support via quant_config
+  - Proper weight loading with stacked_params_mapping
+
+Adapted from fastvideo/models/encoders/qwen2_5.py with the following deltas:
+  1. QK-Norm after Q/K split (Qwen3 adds q_norm, k_norm per attention layer)
+  2. Standard RoPE (not multi-modal / mrope)
+  3. attention_bias=False by default
+"""
+
+from typing import Iterable, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from fastvideo.configs.models.encoders import BaseEncoderOutput, Qwen3Config
+from fastvideo.distributed import get_tp_world_size
+from fastvideo.layers.activation import SiluAndMul
+from fastvideo.layers.layernorm import RMSNorm
+from fastvideo.layers.linear import (MergedColumnParallelLinear,
+                                     QKVParallelLinear, RowParallelLinear)
+from fastvideo.layers.quantization import QuantizationConfig
+from fastvideo.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from fastvideo.models.encoders.base import TextEncoder
+from fastvideo.models.loader.weight_utils import default_weight_loader
+from fastvideo.models.mask_utils import sdpa_mask
+from fastvideo.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+# ---------------------------------------------------------------------------
+# RoPE
+# ---------------------------------------------------------------------------
+
+
+class Qwen3RotaryEmbedding(nn.Module):
+    """Standard RoPE for Qwen3 (non-multimodal)."""
+
+    def __init__(self, config: Qwen3Config, device=None):
+        super().__init__()
+        arch = config.arch_config
+        self.head_dim = arch.head_dim
+        self.base = arch.rope_theta
+        self.attention_scaling = 1.0
+
+        half = self.head_dim // 2
+        inv_freq = 1.0 / (
+            self.base ** (
+                torch.arange(0, half, dtype=torch.int64).float().to(device) /
+                half))
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+
+    def forward(
+        self, x: torch.Tensor, position_ids: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            x: [B, seq, head_dim] (used only for device/dtype)
+            position_ids: [B, seq]
+        Returns:
+            (cos, sin) each [B, seq, head_dim]
+        """
+        inv_freq_expanded = (self.inv_freq[None, :, None].float().expand(
+            position_ids.shape[0], -1, 1))
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = x.device.type
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ 
position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + half = x.shape[-1] // 2 + return torch.cat([-x[..., half:], x[..., :half]], dim=-1) + + +def _apply_rope(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, + sin: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Apply rotary embeddings. q/k: [B, heads, seq, head_dim].""" + cos = cos.unsqueeze(1) # [B, 1, seq, head_dim] + sin = sin.unsqueeze(1) + q_rot = q * cos + _rotate_half(q) * sin + k_rot = k * cos + _rotate_half(k) * sin + return q_rot.to(q.dtype), k_rot.to(k.dtype) + + +# --------------------------------------------------------------------------- +# MLP +# --------------------------------------------------------------------------- + + +class Qwen3MLP(nn.Module): + """SwiGLU MLP with Tensor-Parallel linear layers.""" + + def __init__(self, config: Qwen3Config, + quant_config: QuantizationConfig | None = None, + prefix: str = ""): + super().__init__() + arch = config.arch_config + self.gate_up_proj = MergedColumnParallelLinear( + input_size=arch.hidden_size, + output_sizes=[arch.intermediate_size] * 2, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=arch.intermediate_size, + output_size=arch.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) + self.act_fn = SiluAndMul() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x + + +# --------------------------------------------------------------------------- +# Attention +# --------------------------------------------------------------------------- + + +class Qwen3Attention(nn.Module): + """ + GQA attention with QK-Norm and Tensor Parallelism. + + QK-Norm is the main architectural difference from Qwen2: after splitting + QKV, RMSNorm is applied to Q and K before computing attention. 
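
+    Order of operations sketch (matching the forward below)::
+
+        q = q_norm(q.view(B, S, n_heads, head_dim))   # RMSNorm over head_dim
+        k = k_norm(k.view(B, S, n_kv_heads, head_dim))
+        q, k = apply_rope(q, k, cos, sin)             # RoPE applied after QK-Norm
+    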
+ """ + + def __init__(self, config: Qwen3Config, layer_idx: int, + quant_config: QuantizationConfig | None = None, + prefix: str = ""): + super().__init__() + arch = config.arch_config + self.hidden_size = arch.hidden_size + self.head_dim = arch.head_dim + self.num_heads = arch.num_attention_heads + self.num_kv_heads = arch.num_key_value_heads + self.scaling = self.head_dim ** -0.5 + + tp_size = get_tp_world_size() + assert self.num_heads % tp_size == 0 + self.num_heads_per_rank = self.num_heads // tp_size + # KV heads: if fewer than tp_size, replicate across ranks + self.num_kv_heads_per_rank = max(1, self.num_kv_heads // tp_size) + + self.q_size = self.num_heads_per_rank * self.head_dim + self.kv_size = self.num_kv_heads_per_rank * self.head_dim + self.num_kv_groups = self.num_heads_per_rank // self.num_kv_heads_per_rank + + # Qwen3 uses attention_bias=False for QKV, but True for o_proj + self.qkv_proj = QKVParallelLinear( + hidden_size=self.hidden_size, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_kv_heads, + bias=arch.attention_bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.num_heads * self.head_dim, + output_size=self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + # QK-Norm: Qwen3-specific + self.q_norm = RMSNorm(self.head_dim, eps=arch.rms_norm_eps) + self.k_norm = RMSNorm(self.head_dim, eps=arch.rms_norm_eps) + + # Sliding window (layer-dependent in Qwen3) + layer_types = arch.layer_types or [] + layer_type = (layer_types[layer_idx] + if layer_idx < len(layer_types) else "full_attention") + self.sliding_window = (arch.sliding_window + if layer_type == "sliding_attention" else None) + + @staticmethod + def _repeat_kv(hidden: torch.Tensor, n_rep: int) -> torch.Tensor: + if n_rep == 1: + return hidden + B, heads, seq, dim = hidden.shape + return hidden[:, :, None, :, :].expand(B, heads, n_rep, seq, + dim).reshape(B, heads * n_rep, + seq, dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + bsz, q_len, _ = hidden_states.size() + + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + # [B, seq, heads, head_dim] + q = q.view(bsz, q_len, self.num_heads_per_rank, self.head_dim) + k = k.view(bsz, q_len, self.num_kv_heads_per_rank, self.head_dim) + v = v.view(bsz, q_len, self.num_kv_heads_per_rank, self.head_dim) + + # QK-Norm (Qwen3-specific) + q = self.q_norm(q).to(v.dtype) + k = self.k_norm(k).to(v.dtype) + + # [B, heads, seq, head_dim] for RoPE + attention + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # RoPE + if position_embeddings is not None: + cos, sin = position_embeddings + q, k = _apply_rope(q, k, cos, sin) + + # GQA: expand KV to match Q head count + k = self._repeat_kv(k, self.num_kv_groups) + v = self._repeat_kv(v, self.num_kv_groups) + + # SDPA + if attention_mask is not None and attention_mask.dtype != torch.bool: + attention_mask = attention_mask.bool() + + attn_out = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attention_mask, + dropout_p=0.0, + scale=self.scaling, + is_causal=False, + ) + + # [B, seq, hidden] + attn_out = attn_out.transpose(1, 2).contiguous().reshape(bsz, q_len, -1) + attn_out, _ = 
self.o_proj(attn_out) + return attn_out + + +# --------------------------------------------------------------------------- +# Decoder layer +# --------------------------------------------------------------------------- + + +class Qwen3DecoderLayer(nn.Module): + + def __init__(self, config: Qwen3Config, layer_idx: int, + quant_config: QuantizationConfig | None = None, + prefix: str = ""): + super().__init__() + arch = config.arch_config + self.self_attn = Qwen3Attention(config, + layer_idx, + quant_config=quant_config, + prefix=f"{prefix}.self_attn") + self.mlp = Qwen3MLP(config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") + self.input_layernorm = RMSNorm(arch.hidden_size, eps=arch.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(arch.hidden_size, + eps=arch.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + return residual + hidden_states + + +# --------------------------------------------------------------------------- +# Main model +# --------------------------------------------------------------------------- + + +class Qwen3Model(TextEncoder): + """ + Native FastVideo Qwen3 text encoder. + + Supports: + - Tensor Parallelism (QKVParallelLinear + RowParallelLinear) + - INT8/FP8 quantization (via quant_config) + - FSDP sharding (conditions defined in Qwen3ArchConfig) + - Proper HuggingFace weight loading (stacked_params_mapping) + """ + + def __init__(self, config: Qwen3Config): + super().__init__(config) + arch = config.arch_config + quant_config = getattr(config, "quant_config", None) + + self.embed_tokens = VocabParallelEmbedding( + arch.vocab_size, + arch.hidden_size, + org_num_embeddings=arch.vocab_size, + ) + self.layers = nn.ModuleList([ + Qwen3DecoderLayer( + config, + layer_idx, + quant_config=quant_config, + prefix=f"model.layers.{layer_idx}", + ) for layer_idx in range(arch.num_hidden_layers) + ]) + self.norm = RMSNorm(arch.hidden_size, eps=arch.rms_norm_eps) + self.rotary_emb = Qwen3RotaryEmbedding(config) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_hidden_states: Optional[bool] = None, + **kwargs, + ) -> BaseEncoderOutput: + arch = self.config.arch_config + output_hidden_states = (output_hidden_states if output_hidden_states + is not None else + getattr(arch, "output_hidden_states", False)) + + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + + hidden_states = inputs_embeds + seq_length = hidden_states.shape[1] + + if position_ids is None: + position_ids = torch.arange(seq_length, + device=hidden_states.device).unsqueeze(0) + + position_embeddings = self.rotary_emb(hidden_states, 
position_ids) + + # Build SDPA attention mask + cache_position = torch.arange(seq_length, device=hidden_states.device) + sdpa_attn_mask = None + if attention_mask is not None: + sdpa_attn_mask = sdpa_mask( + batch_size=hidden_states.shape[0], + cache_position=cache_position, + kv_length=attention_mask.shape[-1], + kv_offset=0, + attention_mask=attention_mask, + ) + + all_hidden_states: tuple | None = () if output_hidden_states else None + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + hidden_states = layer( + hidden_states=hidden_states, + attention_mask=sdpa_attn_mask, + position_ids=position_ids, + position_embeddings=position_embeddings, + ) + + hidden_states = self.norm(hidden_states) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + return BaseEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + ) + + def load_weights( + self, + weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """ + Load weights from HuggingFace checkpoint using stacked_params_mapping + to fuse q_proj/k_proj/v_proj -> qkv_proj and gate_proj/up_proj -> gate_up_proj. + """ + arch = self.config.arch_config + stacked_params_mapping = getattr(arch, "stacked_params_mapping", []) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # Try stacked param mapping first + matched = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name_mapped = name.replace(weight_name, param_name) + if name_mapped not in params_dict: + continue + param = params_dict[name_mapped] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + loaded_params.add(name_mapped) + matched = True + break + + if matched: + continue + + if name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + + return loaded_params diff --git a/fastvideo/models/registry.py b/fastvideo/models/registry.py index cba1b5d6fa..5806e04262 100644 --- a/fastvideo/models/registry.py +++ b/fastvideo/models/registry.py @@ -1,7 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/model_executor/models/registry.py -import ast import importlib import os import pickle @@ -25,8 +24,6 @@ _TEXT_TO_VIDEO_DIT_MODELS = { "HunyuanVideoTransformer3DModel": ("dits", "hunyuanvideo", "HunyuanVideoTransformer3DModel"), - "HunyuanGameCraftTransformer3DModel": - ("dits", "hunyuangamecraft", "HunyuanGameCraftTransformer3DModel"), "HunyuanVideo15Transformer3DModel": ("dits", "hunyuanvideo15", "HunyuanVideo15Transformer3DModel"), "HYWorldTransformer3DModel": @@ -39,8 +36,8 @@ "LongCatVideoTransformer3DModel": ("dits", "longcat_video_dit", "LongCatVideoTransformer3DModel"), # Wrapper (Phase 1) "LongCatTransformer3DModel": ("dits", "longcat", "LongCatTransformer3DModel"), # Native (Phase 2) "LTX2Transformer3DModel": ("dits", "ltx2", "LTX2Transformer3DModel"), - "SD3Transformer2DModel": ("dits", "sd3", "SD3Transformer2DModel"), - "LingBotWorldTransformer3DModel": ("dits", "lingbotworld", "LingBotWorldTransformer3DModel"), + # Text-to-Image models + "OvisImageTransformer2DModel": ("dits", "ovisimage", "OvisImageTransformer2DModel"), } 
_IMAGE_TO_VIDEO_DIT_MODELS = { @@ -53,11 +50,9 @@ _TEXT_ENCODER_MODELS = { "CLIPTextModel": ("encoders", "clip", "CLIPTextModel"), - "CLIPTextModelWithProjection": - ("encoders", "clip", "CLIPTextModelWithProjection"), "LlamaModel": ("encoders", "llama", "LlamaModel"), "UMT5EncoderModel": ("encoders", "t5", "UMT5EncoderModel"), - "T5EncoderModel": ("encoders", "t5_hf", "T5EncoderModel"), + "T5EncoderModel": ("encoders", "t5", "T5EncoderModel"), "STEP1TextEncoder": ("encoders", "stepllm", "STEP1TextEncoder"), "BertModel": ("encoders", "clip", "CLIPTextModel"), "Qwen2_5_VLTextModel": ("encoders", "qwen2_5", "Qwen2_5_VLTextModel"), @@ -65,6 +60,7 @@ "Qwen2_5_VLForConditionalGeneration": ("encoders", "reason1", "Reason1TextEncoder"), "LTX2GemmaTextEncoderModel": ("encoders", "gemma", "LTX2GemmaTextEncoderModel"), + "Qwen3Model": ("encoders", "qwen3", "Qwen3Model"), } _IMAGE_ENCODER_MODELS: dict[str, tuple] = { @@ -75,14 +71,13 @@ } _VAE_MODELS = { + "AutoencoderKL": ("vaes", "autoencoderkl", "AutoencoderKL"), "AutoencoderKLHunyuanVideo": ("vaes", "hunyuanvae", "AutoencoderKLHunyuanVideo"), - "AutoencoderKLCausal3D": ("vaes", "gamecraftvae", "GameCraftVAE"), "AutoencoderKLHYWorld": ("vaes", "hyworldvae", "AutoencoderKLHYWorld"), "AutoencoderKLHunyuanVideo15": ("vaes", "hunyuan15vae", "AutoencoderKLHunyuanVideo15"), "AutoencoderKLWan": ("vaes", "wanvae", "AutoencoderKLWan"), "AutoencoderKLStepvideo": ("vaes", "stepvideovae", "AutoencoderKLStepvideo"), - "AutoencoderKL": ("vaes", "autoencoder_kl", "AutoencoderKL"), "CausalVideoAutoencoder": ("vaes", "ltx2vae", "LTX2CausalVideoAutoencoder"), } @@ -107,12 +102,7 @@ ("schedulers", "scheduling_rcm", "RCMScheduler"), } -_UPSAMPLERS = { - "SRTo720pUpsampler": ("upsamplers", "hunyuan15", "SRTo720pUpsampler"), - "SRTo1080pUpsampler": ("upsamplers", "hunyuan15", "SRTo1080pUpsampler"), -} - -_LEGACY_FAST_VIDEO_MODELS = { +_FAST_VIDEO_MODELS = { **_TEXT_TO_VIDEO_DIT_MODELS, **_IMAGE_TO_VIDEO_DIT_MODELS, **_TEXT_ENCODER_MODELS, @@ -120,102 +110,8 @@ **_VAE_MODELS, **_AUDIO_MODELS, **_SCHEDULERS, - **_UPSAMPLERS, } -MODELS_PATH = os.path.dirname(__file__) - - -@lru_cache(maxsize=None) -def _discover_and_register_models() -> dict[str, tuple[str, str, str]]: - discovered_models: dict[str, tuple[str, str, str]] = {} - for root, dirs, files in os.walk(MODELS_PATH): - dirs[:] = [ - d for d in dirs - if not d.startswith(".") and d != "__pycache__" - ] - - for filename in files: - if not filename.endswith(".py"): - continue - - filepath = os.path.join(root, filename) - try: - with open(filepath, "r", encoding="utf-8") as f: - source = f.read() - tree = ast.parse(source, filename=filename) - - entry_class_node = None - first_class_def = None - - for node in ast.walk(tree): - if isinstance(node, ast.Assign): - for target in node.targets: - if isinstance(target, ast.Name) and target.id == "EntryClass": - entry_class_node = node - break - if first_class_def is None and isinstance(node, ast.ClassDef): - first_class_def = node - - if not entry_class_node or not first_class_def: - continue - - model_cls_name_list: list[str] = [] - value_node = entry_class_node.value - - if isinstance(value_node, ast.Name): - model_cls_name_list.append(value_node.id) - elif isinstance(value_node, (ast.List, ast.Tuple)): - for elt in value_node.elts: - if isinstance(elt, ast.Constant) and isinstance( - elt.value, str): - model_cls_name_list.append(elt.value) - elif isinstance(elt, ast.Name): - model_cls_name_list.append(elt.id) - - if not model_cls_name_list: - continue - - rel_dir = 
os.path.relpath(root, MODELS_PATH) - if rel_dir == ".": - continue - - rel_parts = rel_dir.split(os.sep) - component_name = rel_parts[0] - sub_parts = rel_parts[1:] - - if filename == "__init__.py": - mod_relname = ".".join(sub_parts) - else: - mod_base = filename[:-3] - mod_relname = ".".join(sub_parts + - [mod_base]) if sub_parts else mod_base - - for model_cls_str in model_cls_name_list: - if model_cls_str in discovered_models: - logger.warning( - "Duplicate architecture found: %s. Overwriting.", - model_cls_str) - discovered_models[model_cls_str] = ( - component_name, - mod_relname, - model_cls_str, - ) - - except Exception as e: - logger.warning("Could not parse %s to find models: %s", - filepath, e) - - return discovered_models - - -_DISCOVERED_MODELS = _discover_and_register_models() -_FAST_VIDEO_MODELS = dict(_DISCOVERED_MODELS) -for model_arch, spec in _LEGACY_FAST_VIDEO_MODELS.items(): - if model_arch in _FAST_VIDEO_MODELS: - continue - _FAST_VIDEO_MODELS[model_arch] = spec - _SUBPROCESS_COMMAND = [sys.executable, "-m", "fastvideo.models.dits.registry"] _T = TypeVar("_T") @@ -447,11 +343,10 @@ def resolve_model_cls( ModelRegistry = _ModelRegistry({ model_arch: _LazyRegisteredModel( - module_name=(f"fastvideo.models.{component_name}.{mod_relname}" - if mod_relname else f"fastvideo.models.{component_name}"), + module_name=f"fastvideo.models.{component_name}.{mod_relname}", component_name=component_name, class_name=cls_name, ) for model_arch, (component_name, mod_relname, cls_name) in _FAST_VIDEO_MODELS.items() -}) \ No newline at end of file +}) diff --git a/fastvideo/pipelines/basic/ovis_image/__init__.py b/fastvideo/pipelines/basic/ovis_image/__init__.py new file mode 100644 index 0000000000..0d33cc44ce --- /dev/null +++ b/fastvideo/pipelines/basic/ovis_image/__init__.py @@ -0,0 +1,6 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Ovis-Image text-to-image pipeline.""" + +from .ovis_image_pipeline import OvisImagePipeline + +__all__ = ["OvisImagePipeline"] diff --git a/fastvideo/pipelines/basic/ovis_image/ovis_image_pipeline.py b/fastvideo/pipelines/basic/ovis_image/ovis_image_pipeline.py new file mode 100644 index 0000000000..1022de8d09 --- /dev/null +++ b/fastvideo/pipelines/basic/ovis_image/ovis_image_pipeline.py @@ -0,0 +1,120 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Ovis-Image text-to-image diffusion pipeline implementation. + +This module implements the Ovis-Image T2I pipeline using Diffusers components directly. +This is Approach A - using existing Diffusers classes for quick integration. +""" + +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler) +from fastvideo.pipelines.composed_pipeline_base import ComposedPipelineBase +from fastvideo.pipelines.stages import (ConditioningStage, DecodingStage, + DenoisingStage, InputValidationStage, + LatentPreparationStage, + TextEncodingStage, + TimestepPreparationStage) + +logger = init_logger(__name__) + + +class OvisImagePipeline(ComposedPipelineBase): + """ + Pipeline for Ovis-Image text-to-image generation. + + Ovis-Image is a 7B parameter model optimized for high-quality text rendering + in generated images. 
It uses: + - OvisImageTransformer2DModel: 2D diffusion transformer + - Qwen3Model: Text encoder based on Ovis2.5-2B + - AutoencoderKL: VAE for image encoding/decoding + - FlowMatchEulerDiscreteScheduler: Flow-matching scheduler + """ + + _required_config_modules = [ + "text_encoder", "tokenizer", "vae", "transformer", "scheduler" + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs): + """ + Initialize pipeline-specific configurations. + + Sets up the scheduler with Ovis-Image specific parameters. + """ + # Use the scheduler from model config + # The scheduler is already loaded from the model, we just need to configure it + if self.modules.get("scheduler") is None: + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=3.0, + use_dynamic_shifting=False, # Disable for Approach A + ) + + # Configure scheduler parameters if needed + scheduler = self.modules["scheduler"] + if hasattr(scheduler, 'config'): + # Disable dynamic shifting for Approach A (requires mu parameter) + scheduler.config.use_dynamic_shifting = False + # Update config if needed based on fastvideo_args + if hasattr(fastvideo_args.pipeline_config, 'flow_shift'): + scheduler.config.shift = fastvideo_args.pipeline_config.flow_shift + + def create_pipeline_stages(self, fastvideo_args: FastVideoArgs): + """ + Set up pipeline stages for Ovis-Image T2I generation. + + Pipeline flow: + 1. Input validation - check dimensions + 2. Text encoding - encode prompt with Qwen3 + 3. Conditioning - prepare CFG guidance + 4. Timestep preparation - setup diffusion schedule + 5. Latent preparation - initialize noise + 6. Denoising - iterative denoising with transformer + 7. Decoding - VAE decode to image + """ + + # Stage 1: Validate input dimensions + self.add_stage(stage_name="input_validation_stage", + stage=InputValidationStage()) + + # Stage 2: Encode text prompts with Qwen3 encoder + self.add_stage(stage_name="prompt_encoding_stage", + stage=TextEncodingStage( + text_encoders=[self.get_module("text_encoder")], + tokenizers=[self.get_module("tokenizer")], + )) + + # Stage 3: Prepare conditioning for classifier-free guidance + self.add_stage(stage_name="conditioning_stage", + stage=ConditioningStage()) + + # Stage 4: Prepare timesteps for diffusion process + self.add_stage(stage_name="timestep_preparation_stage", + stage=TimestepPreparationStage( + scheduler=self.get_module("scheduler"))) + + # Stage 5: Prepare initial latent noise + # For T2I, num_frames=1 (single image) + self.add_stage( + stage_name="latent_preparation_stage", + stage=LatentPreparationStage( + scheduler=self.get_module("scheduler"), + transformer=self.get_module("transformer"), + use_btchw_layout=False # Use standard layout + )) + + # Stage 6: Denoising loop with OvisImageTransformer2DModel + self.add_stage(stage_name="denoising_stage", + stage=DenoisingStage( + transformer=self.get_module("transformer"), + scheduler=self.get_module("scheduler"), + vae=self.get_module("vae"))) + + # Stage 7: Decode latents to image + self.add_stage(stage_name="decoding_stage", + stage=DecodingStage(vae=self.get_module("vae"))) + + +# Entry point for pipeline registry +EntryClass = OvisImagePipeline diff --git a/fastvideo/pipelines/pipeline_registry.py b/fastvideo/pipelines/pipeline_registry.py index b4b9c610a5..53294b00ba 100644 --- a/fastvideo/pipelines/pipeline_registry.py +++ b/fastvideo/pipelines/pipeline_registry.py @@ -16,6 +16,30 @@ logger = init_logger(__name__) +# map pipeline name to folder name 
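+# Each value names the package directory under fastvideo/pipelines/<pipeline_type>/
+# that holds the pipeline's EntryClass. New pipelines must be added here, since
+# _try_load_pipeline_cls resolves classes through this mapping.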
+_PIPELINE_NAME_TO_ARCHITECTURE_NAME: dict[str, str] = {
+    "WanPipeline": "wan",
+    "WanDMDPipeline": "wan",
+    "WanImageToVideoPipeline": "wan",
+    "WanVideoToVideoPipeline": "wan",
+    "WanCausalDMDPipeline": "wan",
+    "TurboDiffusionPipeline": "turbodiffusion",
+    "TurboDiffusionI2VPipeline": "turbodiffusion",
+    "StepVideoPipeline": "stepvideo",
+    "HunyuanVideoPipeline": "hunyuan",
+    "HunyuanVideo15Pipeline": "hunyuan15",
+    "HYWorldPipeline": "hyworld",
+    "Cosmos2VideoToWorldPipeline": "cosmos",
+    "Cosmos2_5Pipeline": "cosmos",
+    "MatrixGamePipeline": "matrixgame",
+    "MatrixGameCausalDMDPipeline": "matrixgame",
+    "LongCatPipeline": "longcat",
+    "LongCatImageToVideoPipeline": "longcat",
+    "LongCatVideoContinuationPipeline": "longcat",
+    "LTX2Pipeline": "ltx2",
+    "OvisImagePipeline": "ovis_image",
+}
+
 _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME: dict[WorkloadType, str] = {
     WorkloadType.I2V: "PreprocessPipelineI2V",
     WorkloadType.T2V: "PreprocessPipelineT2V",
@@ -50,36 +74,44 @@ def choices(cls) -> list[str]:
 
 @dataclass
 class _PipelineRegistry:
-    # Keyed by pipeline_type -> pipeline_name
-    # pipelines[pipeline_type][pipeline_name] = pipeline_cls
-    pipelines: dict[str, dict[str, type[ComposedPipelineBase]
-                              | None]] = field(default_factory=dict)
+    # Keyed by pipeline_type -> architecture -> pipeline_name
+    # pipelines[pipeline_type][architecture][pipeline_name] = pipeline_cls
+    pipelines: dict[str, dict[str, dict[str, type[ComposedPipelineBase]
+                                        | None]]] = field(default_factory=dict)
 
-    def get_supported_pipelines(self, pipeline_type: PipelineType) -> Set[str]:
-        """Get supported pipelines for the given pipeline type."""
-        return set(self.pipelines.get(pipeline_type.value, {}).keys())
+    def get_supported_pipelines(self, pipeline_name_in_config: str,
+                                pipeline_type: PipelineType) -> Set[str]:
+        """Get the pipeline names registered under this pipeline's
+        architecture for the given pipeline type."""
+        arch = _PIPELINE_NAME_TO_ARCHITECTURE_NAME.get(pipeline_name_in_config)
+        if arch is None:
+            return set()
+        return set(
+            self.pipelines.get(pipeline_type.value, {}).get(arch, {}).keys())
 
     def _load_preprocess_pipeline_cls(
-            self,
-            workload_type: WorkloadType) -> type[ComposedPipelineBase] | None:
+            self, workload_type: WorkloadType,
+            arch: str) -> type[ComposedPipelineBase] | None:
         pipeline_name = _PREPROCESS_WORKLOAD_TYPE_TO_PIPELINE_NAME[
             workload_type]
-        return self.pipelines.get(PipelineType.PREPROCESS.value,
-                                  {}).get(pipeline_name)
+
+        return self.pipelines[
+            PipelineType.PREPROCESS.value][arch].get(pipeline_name)
 
     def _try_load_pipeline_cls(
         self, pipeline_name_in_config: str, pipeline_type: PipelineType,
         workload_type: WorkloadType
     ) -> type[ComposedPipelineBase] | type[LoRAPipeline] | None:
         """Try to load a pipeline class for the given architecture, pipeline type, and workload type."""
+        # Use .get() so an unknown pipeline name falls through to the
+        # informative ValueError in resolve_pipeline_cls instead of a KeyError.
+        arch = _PIPELINE_NAME_TO_ARCHITECTURE_NAME.get(pipeline_name_in_config)
 
-        if pipeline_type.value not in self.pipelines:
+        if (arch is None or pipeline_type.value not in self.pipelines
+                or arch not in self.pipelines[pipeline_type.value]):
             return None
 
         if pipeline_type == PipelineType.PREPROCESS:
-            return self._load_preprocess_pipeline_cls(workload_type)
-        elif pipeline_type == PipelineType.BASIC or pipeline_type == PipelineType.TRAINING:
-            return self.pipelines[pipeline_type.value].get(
-                pipeline_name_in_config)
+            return self._load_preprocess_pipeline_cls(workload_type, arch)
+        elif pipeline_type == PipelineType.BASIC:
+            return self.pipelines[pipeline_type.value][arch].get(
+                pipeline_name_in_config)
+        elif pipeline_type == PipelineType.TRAINING:
+            # Training pipelines are not resolved through this registry;
+            # returning None surfaces the error in resolve_pipeline_cls.
+            return None
         else:
             raise ValueError(f"Invalid pipeline type: {pipeline_type.value}")
@@ -99,28 +131,18 @@ def resolve_pipeline_cls(
             pipeline_type, workload_type)
         if pipeline_cls is not None:
             return pipeline_cls
-        supported_pipelines = self.get_supported_pipelines(pipeline_type)
+        supported_pipelines = self.get_supported_pipelines(
+            pipeline_name_in_config, pipeline_type)
         raise ValueError(
             f"Pipeline '{pipeline_name_in_config}' is not supported for pipeline type '{pipeline_type.value}' "
             f"and workload type '{workload_type.value}'. "
-            f"Supported pipelines: {supported_pipelines}")
+            f"Supported pipelines for this architecture: {supported_pipelines}")
 
 
 def import_pipeline_classes(
     pipeline_types: list[PipelineType] | PipelineType | None = None
-) -> dict[str, dict[str, type[ComposedPipelineBase] | None]]:
+) -> dict[str, dict[str, dict[str, type[ComposedPipelineBase] | None]]]:
+    # lru_cache needs hashable arguments, so normalize list inputs to tuples
+    # before delegating to the cached implementation below.
     pipeline_types_key: tuple[PipelineType, ...] | PipelineType | None
     if isinstance(pipeline_types, list):
         pipeline_types_key = tuple(pipeline_types)
     else:
         pipeline_types_key = pipeline_types
     return _import_pipeline_classes_cached(pipeline_types_key)
 
 
 @lru_cache
 def _import_pipeline_classes_cached(
     pipeline_types: tuple[PipelineType, ...] | PipelineType | None = None
-) -> dict[str, dict[str, type[ComposedPipelineBase] | None]]:
+) -> dict[str, dict[str, dict[str, type[ComposedPipelineBase] | None]]]:
     """
     Import pipeline classes based on the pipeline type and workload type.
@@ -129,16 +151,19 @@
         If None, loads all types.
 
     Returns:
-        A two-level nested dictionary:
-        {pipeline_type: {pipeline_name: pipeline_cls}}
-        e.g., {"basic": {"WanPipeline": WanPipeline}}
+        A three-level nested dictionary:
+        {pipeline_type: {architecture_name: {pipeline_name: pipeline_cls}}}
+        e.g., {"basic": {"wan": {"WanPipeline": WanPipeline}}}
     """
-    type_to_pipeline_dict: dict[str, dict[str, type[ComposedPipelineBase]
-                                          | None]] = {}
+    type_to_arch_to_pipeline_dict: dict[str,
+                                        dict[str,
+                                             dict[str,
+                                                  type[ComposedPipelineBase]
+                                                  | None]]] = {}
     package_name: str = "fastvideo.pipelines"
 
     # Determine which pipeline types to scan
     if isinstance(pipeline_types, tuple):
         pipeline_types_to_scan = [
             pipeline_type.value for pipeline_type in pipeline_types
         ]
@@ -150,7 +175,8 @@
     logger.info("Loading pipelines for types: %s", pipeline_types_to_scan)
 
     for pipeline_type_str in pipeline_types_to_scan:
-        pipeline_dict: dict[str, type[ComposedPipelineBase] | None] = {}
+        arch_to_pipeline_dict: dict[str, dict[str, type[ComposedPipelineBase]
+                                              | None]] = {}
 
         # Try to load from pipeline-type-specific directory first
         pipeline_type_package_name = f"{package_name}.{pipeline_type_str}"
@@ -162,45 +188,50 @@
 
             for _, arch, ispkg in pkgutil.iter_modules(
                     pipeline_type_package.__path__):
+                pipeline_dict: dict[str, type[ComposedPipelineBase] | None] = {}
+
                 arch_package_name = f"{pipeline_type_package_name}.{arch}"
-                if not ispkg:
-                    continue
-
-                arch_package = importlib.import_module(arch_package_name)
-                for _, module_name, ispkg in pkgutil.walk_packages(
-                        arch_package.__path__, arch_package_name + "."):
-                    if ispkg:
-                        continue
-                    pipeline_module = importlib.import_module(module_name)
-                    if not hasattr(pipeline_module, "EntryClass"):
-                        continue
-                    entry_cls = pipeline_module.EntryClass
-                    entry_cls_list = ([
-                        entry_cls
-                    ] if not isinstance(entry_cls, list) else entry_cls)
-
-                    for pipeline in entry_cls_list:
-                        pipeline_name = pipeline.__name__
-                        if 
pipeline_name in pipeline_dict: - logger.warning( - "Duplicate pipeline name '%s' found in %s. Overwriting.", - pipeline_name, pipeline_type_str) - pipeline_dict[pipeline_name] = pipeline + if ispkg: + arch_package = importlib.import_module(arch_package_name) + for _, module_name, ispkg in pkgutil.walk_packages( + arch_package.__path__, arch_package_name + "."): + if not ispkg: + pipeline_module = importlib.import_module( + module_name) + if hasattr(pipeline_module, "EntryClass"): + if isinstance(pipeline_module.EntryClass, list): + for pipeline in pipeline_module.EntryClass: + pipeline_name = pipeline.__name__ + assert ( + pipeline_name not in pipeline_dict + ), f"Duplicated pipeline implementation for {pipeline_name} in {pipeline_type_str}.{arch_package_name}" + pipeline_dict[pipeline_name] = pipeline + else: + pipeline_name = pipeline_module.EntryClass.__name__ + assert ( + pipeline_name not in pipeline_dict + ), f"Duplicated pipeline implementation for {pipeline_name} in {pipeline_type_str}.{arch_package_name}" + pipeline_dict[ + pipeline_name] = pipeline_module.EntryClass + + arch_to_pipeline_dict[arch] = pipeline_dict except ImportError as e: raise ImportError( f"Could not import {pipeline_type_package_name} when importing pipeline classes: {e}" ) from None - type_to_pipeline_dict[pipeline_type_str] = pipeline_dict + type_to_arch_to_pipeline_dict[pipeline_type_str] = arch_to_pipeline_dict # Log summary total_pipelines = sum( - len(pipeline_dict) for pipeline_dict in type_to_pipeline_dict.values()) + len(pipeline_dict) + for arch_to_pipeline_dict in type_to_arch_to_pipeline_dict.values() + for pipeline_dict in arch_to_pipeline_dict.values()) logger.info("Loaded %d pipeline classes across %d types", total_pipelines, len(pipeline_types_to_scan)) - return type_to_pipeline_dict + return type_to_arch_to_pipeline_dict def get_pipeline_registry( diff --git a/fastvideo/pipelines/stages/causal_denoising.py b/fastvideo/pipelines/stages/causal_denoising.py index cbd44fd1e2..e49150a7a7 100644 --- a/fastvideo/pipelines/stages/causal_denoising.py +++ b/fastvideo/pipelines/stages/causal_denoising.py @@ -14,7 +14,7 @@ from fastvideo.attention.backends.sliding_tile_attn import ( SlidingTileAttentionBackend) st_attn_available = True -except ImportError: +except (ImportError, RuntimeError): st_attn_available = False SlidingTileAttentionBackend = None # type: ignore @@ -22,7 +22,7 @@ from fastvideo.attention.backends.video_sparse_attn import ( VideoSparseAttentionBackend) vsa_available = True -except ImportError: +except (ImportError, RuntimeError): vsa_available = False VideoSparseAttentionBackend = None # type: ignore @@ -436,9 +436,9 @@ def _initialize_kv_cache(self, batch_size, dtype, device) -> list[dict]: dtype=dtype, device=device), "global_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), "local_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), }) return kv_cache1 @@ -494,4 +494,4 @@ def verify_input(self, batch: ForwardBatch, result.add_check( "negative_prompt_embeds", batch.negative_prompt_embeds, lambda x: not batch.do_classifier_free_guidance or V.list_not_empty(x)) - return result + return result \ No newline at end of file diff --git a/fastvideo/pipelines/stages/denoising.py b/fastvideo/pipelines/stages/denoising.py index 0630fdc5b2..535c489a84 100644 --- a/fastvideo/pipelines/stages/denoising.py +++ b/fastvideo/pipelines/stages/denoising.py @@ -32,21 +32,21 @@ from fastvideo.attention.backends.sliding_tile_attn import ( SlidingTileAttentionBackend) 
st_attn_available = True -except ImportError: +except (ImportError, RuntimeError): st_attn_available = False try: from fastvideo.attention.backends.vmoba import VMOBAAttentionBackend from fastvideo.utils import is_vmoba_available vmoba_attn_available = is_vmoba_available() -except ImportError: +except (ImportError, RuntimeError): vmoba_attn_available = False try: from fastvideo.attention.backends.video_sparse_attn import ( VideoSparseAttentionBackend) vsa_available = True -except ImportError: +except (ImportError, RuntimeError): vsa_available = False logger = init_logger(__name__) @@ -173,14 +173,6 @@ def forward( { "mouse_cond": batch.mouse_cond, "keyboard_cond": batch.keyboard_cond, - "c2ws_plucker_emb": batch.c2ws_plucker_emb, - }, - ) - - camera_kwargs = self.prepare_extra_func_kwargs( - self.transformer.forward, - { - "camera_states": batch.camera_states, }, ) @@ -322,27 +314,6 @@ def forward( t_expand = timestep.repeat(latent_model_input.shape[0], 1) else: t_expand = t.repeat(latent_model_input.shape[0]) - t_expand = t_expand.to(get_local_torch_device()) - - use_meanflow = getattr(self.transformer.config, "use_meanflow", - False) - if use_meanflow: - if i == len(timesteps) - 1: - timesteps_r = torch.tensor( - [0.0], device=get_local_torch_device()) - else: - timesteps_r = timesteps[i + 1] - timesteps_r = timesteps_r.repeat( - latent_model_input.shape[0]) - else: - timesteps_r = None - - timesteps_r_kwarg = self.prepare_extra_func_kwargs( - self.transformer.forward, - { - "timestep_r": timesteps_r, - }, - ) latent_model_input = self.scheduler.scale_model_input( latent_model_input, t) @@ -435,8 +406,6 @@ def forward( **image_kwargs, **pos_cond_kwargs, **action_kwargs, - **camera_kwargs, - **timesteps_r_kwarg, ) if batch.do_classifier_free_guidance: @@ -454,8 +423,6 @@ def forward( **image_kwargs, **neg_cond_kwargs, **action_kwargs, - **camera_kwargs, - **timesteps_r_kwarg, ) noise_pred_text = noise_pred @@ -520,7 +487,7 @@ def forward( mgr2.release_all() # Save STA mask search results if needed - if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.pipeline_config.STA_mode == STA_Mode.STA_SEARCHING: + if st_attn_available and self.attn_backend == SlidingTileAttentionBackend and fastvideo_args.STA_mode == STA_Mode.STA_SEARCHING: self.save_sta_search_results(batch) # deallocate transformer if on mps @@ -613,8 +580,8 @@ def prepare_sta_param(self, batch: ForwardBatch, """ # TODO(kevin): STA mask search, currently only support Wan2.1 with 69x768x1280 from fastvideo.attention.backends.STA_configuration import configure_sta - STA_mode = fastvideo_args.pipeline_config.STA_mode - skip_time_steps = fastvideo_args.pipeline_config.skip_time_steps + STA_mode = fastvideo_args.STA_mode + skip_time_steps = fastvideo_args.skip_time_steps if batch.timesteps is None: raise ValueError("Timesteps must be provided") timesteps_num = batch.timesteps.shape[0] @@ -1090,6 +1057,7 @@ def forward( "latents must be provided for Cosmos25DenoisingStage") guidance_scale = batch.guidance_scale + # Use timesteps prepared by Cosmos25TimestepPreparationStage when available. 
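+        # If the preparation stage did not populate batch.timesteps, fall
+        # back to the scheduler's own schedule for num_inference_steps.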
         if batch.timesteps is None:
+            self.scheduler.set_timesteps(batch.num_inference_steps,
+                                         device=latents.device)
@@ -1097,45 +1065,20 @@
         else:
             timesteps = batch.timesteps.to(latents.device)
 
-        cfg = fastvideo_args.pipeline_config
-
-        if batch.fps is None:
-            gen = batch.generator
-            if isinstance(gen, list) and len(gen) > 0:
-                gen = gen[0]
-            fps_tensor = torch.randint(
-                16,
-                32,
-                (1, ),
-                generator=gen if isinstance(gen, torch.Generator) else None,
-                device=latents.device,
-            ).float().to(dtype=target_dtype)
-        else:
-            fps_val = batch.fps
-            fps_tensor = torch.tensor(
-                [fps_val],
-                device=latents.device,
-                dtype=target_dtype,
-            )
+        # Match official behavior: pass fps as a tensor.
+        fps_val = batch.fps if isinstance(batch.fps, int | float) else 24
+        fps_tensor = torch.tensor([fps_val],
+                                  device=latents.device,
+                                  dtype=target_dtype)
 
+        # Cosmos2.5 denoises a 4D latent (C,T,H,W) and the scheduler.step expects (B,C,T,H,W).
         latents_4d = latents[0]
 
-        # Masks are optional for T2W.
-        cond_mask = getattr(batch, "cond_mask", None)
-        condition_mask = cond_mask.to(target_dtype) if isinstance(
-            cond_mask, torch.Tensor) else None
-        pad_mask = getattr(batch, "padding_mask", None)
-        padding_mask = pad_mask.to(target_dtype) if isinstance(
-            pad_mask, torch.Tensor) else None
-
-        # Conditioning fields are attached by latent preparation stage.
-        conditioning_latents = getattr(batch, "conditioning_latents", None)
-        cond_indicator = getattr(batch, "cond_indicator", None)
-        # Infer whether this is a conditioned run (V2W/I2W) purely from the presence
-        # of conditioning latents. Avoid carrying explicit mode flags on the batch.
-        is_conditioned = (conditioning_latents is not None)
-
-        init_noise_4d = latents_4d.clone()
+        # Masks from the latent prep stage (may be absent or None on the batch,
+        # so guard with isinstance rather than hasattr).
+        cond_mask = getattr(batch, 'cond_mask', None)
+        condition_mask = cond_mask.to(target_dtype) if isinstance(
+            cond_mask, torch.Tensor) else None
+        pad_mask = getattr(batch, 'padding_mask', None)
+        padding_mask = pad_mask.to(target_dtype) if isinstance(
+            pad_mask, torch.Tensor) else None
         if condition_mask is None:
             _, t, h, w = latents_4d.shape
             condition_mask = torch.zeros(1,
@@ -1147,58 +1090,23 @@
                                          dtype=target_dtype)
         if padding_mask is None:
             _, _, h, w = latents_4d.shape
-            padding_default = 0.0 if is_conditioned else 1.0
-            padding_mask = torch.full(
-                (1, 1, h, w),
-                float(padding_default),
-                device=latents.device,
-                dtype=target_dtype,
-            )
-
+            padding_mask = torch.ones(1,
+                                      1,
+                                      h,
+                                      w,
+                                      device=latents.device,
+                                      dtype=target_dtype)
+
+        # Cosmos2.5 scales timesteps by 0.001 before passing them to the transformer.
         timestep_scale = 0.001
-        state_dtype = torch.float32
-
-        conditional_frame_timestep = 0.1
-        latents_4d = latents_4d.to(state_dtype)
-        init_noise_4d = init_noise_4d.to(state_dtype)
-
-        clamp_every_step = bool(getattr(cfg, "cosmos25_clamp_every_step",
-                                        True)) if is_conditioned else False
-
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             for i, t in enumerate(timesteps):
                 t_val = float(t)
-                if is_conditioned:
-                    t_frames = int(latents_4d.shape[1])
-                    timestep = torch.full(
-                        (1, t_frames),
-                        float(t_val * timestep_scale),
-                        device=latents.device,
-                        dtype=torch.float32,
-                    )
-                    if cond_indicator is not None and t_frames > 0:
-                        cond_t = cond_indicator[0, 0, :t_frames, 0, 0]
-                        cond_mask_t = (cond_t > 0.5)
-                        if bool(cond_mask_t.any().item()):
-                            timestep[0, cond_mask_t] = float(
-                                conditional_frame_timestep)
-                else:
-                    timestep_val = t_val * timestep_scale
-                    timestep = torch.tensor(
-                        [[float(timestep_val)]],
-                        device=latents.device,
-                        dtype=target_dtype,
-                    )
-
-                # Conditioned runs: replace x_t with GT x0 on the conditioned frames. 
- if (is_conditioned and cond_indicator is not None - and conditioning_latents is not None - and (clamp_every_step or i == 0)): - cond_ind_4d = cond_indicator[0].to(state_dtype) - gt_x0 = conditioning_latents[0].to(state_dtype) - latents_4d = gt_x0 * cond_ind_4d + latents_4d * ( - 1 - cond_ind_4d) + timestep_val = t_val * timestep_scale + timestep = torch.tensor([[timestep_val]], + device=latents.device, + dtype=target_dtype) model_hidden_states = latents_4d.unsqueeze(0) @@ -1232,22 +1140,10 @@ def forward( padding_mask=padding_mask, return_dict=False, )[0] - if is_conditioned: - v = cond_v + guidance_scale * (cond_v - uncond_v) - else: - v = uncond_v + guidance_scale * (cond_v - uncond_v) + v = uncond_v + guidance_scale * (cond_v - uncond_v) else: v = cond_v - # Conditioned runs: replace velocity on conditioned frames with GT velocity. - if (is_conditioned and cond_indicator is not None - and conditioning_latents is not None): - cond_ind_4d = cond_indicator[0].to(state_dtype) - gt_x0 = conditioning_latents[0].to(state_dtype) - gt_v = init_noise_4d.to(state_dtype) - gt_x0 - v = cond_ind_4d * gt_v + (1 - - cond_ind_4d) * v.to(state_dtype) - prev = self.scheduler.step(v.unsqueeze(0), t, latents_4d.unsqueeze(0), @@ -1257,79 +1153,10 @@ def forward( progress_bar.update() - batch.latents = latents_4d.to(target_dtype).unsqueeze(0) + batch.latents = latents_4d.unsqueeze(0) return batch -class Cosmos25T2WDenoisingStage(Cosmos25DenoisingStage): - """Cosmos 2.5 Text2World denoising stage.""" - - _CONDITIONING_FIELDS = ( - "conditioning_latents", - "cond_indicator", - "uncond_indicator", - ) - - def forward( - self, - batch: ForwardBatch, - fastvideo_args: FastVideoArgs, - ) -> ForwardBatch: - for name in self._CONDITIONING_FIELDS: - if hasattr(batch, name): - setattr(batch, name, None) - return super().forward(batch, fastvideo_args) - - -class Cosmos25V2WDenoisingStage(Cosmos25DenoisingStage): - """Cosmos 2.5 Video2World denoising stage.""" - - def forward( - self, - batch: ForwardBatch, - fastvideo_args: FastVideoArgs, - ) -> ForwardBatch: - return super().forward(batch, fastvideo_args) - - -class Cosmos25AutoDenoisingStage(PipelineStage): - """Route Cosmos 2.5 denoising to T2W vs V2W/I2W.""" - - def __init__(self, transformer, scheduler) -> None: - super().__init__() - self._t2w = Cosmos25T2WDenoisingStage(transformer=transformer, - scheduler=scheduler) - self._v2w = Cosmos25V2WDenoisingStage(transformer=transformer, - scheduler=scheduler) - - def pipeline(self): - return self._v2w.pipeline() if self._v2w.pipeline else None - - def forward( - self, - batch: ForwardBatch, - fastvideo_args: FastVideoArgs, - ) -> ForwardBatch: - conditioning_latents = getattr(batch, "conditioning_latents", None) - if conditioning_latents is not None: - return self._v2w.forward(batch, fastvideo_args) - return self._t2w.forward(batch, fastvideo_args) - - def verify_input(self, batch: ForwardBatch, - fastvideo_args: FastVideoArgs) -> VerificationResult: - conditioning_latents = getattr(batch, "conditioning_latents", None) - if conditioning_latents is not None: - return self._v2w.verify_input(batch, fastvideo_args) - return self._t2w.verify_input(batch, fastvideo_args) - - def verify_output(self, batch: ForwardBatch, - fastvideo_args: FastVideoArgs) -> VerificationResult: - conditioning_latents = getattr(batch, "conditioning_latents", None) - if conditioning_latents is not None: - return self._v2w.verify_output(batch, fastvideo_args) - return self._t2w.verify_output(batch, fastvideo_args) - - class 
DmdDenoisingStage(DenoisingStage): """ Denoising stage for DMD. diff --git a/fastvideo/pipelines/stages/matrixgame_denoising.py b/fastvideo/pipelines/stages/matrixgame_denoising.py index 7d46f146eb..88b9272612 100644 --- a/fastvideo/pipelines/stages/matrixgame_denoising.py +++ b/fastvideo/pipelines/stages/matrixgame_denoising.py @@ -19,7 +19,7 @@ from fastvideo.attention.backends.sliding_tile_attn import ( SlidingTileAttentionBackend) st_attn_available = True -except ImportError: +except (ImportError, RuntimeError): st_attn_available = False SlidingTileAttentionBackend = None # type: ignore @@ -27,7 +27,7 @@ from fastvideo.attention.backends.video_sparse_attn import ( VideoSparseAttentionBackend) vsa_available = True -except ImportError: +except (ImportError, RuntimeError): vsa_available = False VideoSparseAttentionBackend = None # type: ignore @@ -335,9 +335,9 @@ def _initialize_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), "local_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), }) return kv_cache @@ -373,9 +373,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), "local_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), }) kv_cache_mouse.append({ "k": @@ -393,9 +393,9 @@ def _initialize_action_kv_cache(self, batch_size: int, dtype: torch.dtype, dtype=dtype, device=device), "global_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), "local_end_index": - 0, + torch.tensor([0], dtype=torch.long, device=device), }) return kv_cache_mouse, kv_cache_keyboard diff --git a/fastvideo/registry.py b/fastvideo/registry.py index cd7275d6c2..cd83a92c96 100644 --- a/fastvideo/registry.py +++ b/fastvideo/registry.py @@ -47,6 +47,7 @@ WanT2V480PConfig, WanT2V720PConfig, ) +from fastvideo.configs.pipelines.ovis_image import OvisImageT2IConfig from fastvideo.configs.pipelines.sd35 import SD35Config from fastvideo.configs.sample.base import SamplingParam from fastvideo.configs.sample.cosmos import ( @@ -571,6 +572,19 @@ def _register_configs() -> None: ], ) + # Ovis-Image + register_configs( + sampling_param_cls=None, + pipeline_config_cls=OvisImageT2IConfig, + hf_model_paths=[ + "AIDC-AI/Ovis-Image-7B", + ], + model_detectors=[ + lambda path: any(token in path.lower() + for token in ("ovis-image", "ovis_image")), + ], + ) + # SD3.5 register_configs( sampling_param_cls=SD35SamplingParam, diff --git a/fastvideo/tests/encoders/test_qwen3_encoder.py b/fastvideo/tests/encoders/test_qwen3_encoder.py new file mode 100644 index 0000000000..368718ee5f --- /dev/null +++ b/fastvideo/tests/encoders/test_qwen3_encoder.py @@ -0,0 +1,146 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Parity test: FastVideo Qwen3Model vs HuggingFace Qwen3Model for Ovis-Image. + +Mirrors the pattern of test_qwen2_5_encoder.py: + - Loads the real Ovis2.5-2B text encoder from local weights + - Compares FastVideo's Qwen3Model output against the HF baseline + - Checks key weight values and final hidden state numerically + +Set OVIS_WEIGHTS env var to the local model root, e.g. 
+ OVIS_WEIGHTS=official_weights/ovis_image \ + pytest fastvideo/tests/encoders/test_qwen3_encoder.py -vs +""" + +import os + +import pytest +import torch +from torch.distributed.tensor import DTensor +from torch.testing import assert_close +from transformers import AutoConfig, AutoTokenizer, Qwen3Model as HFQwen3Model + +from fastvideo.configs.models.encoders import Qwen3Config +from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import TextEncoderLoader +from fastvideo.utils import PRECISION_TO_TYPE + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29509" + +LOCAL_WEIGHTS = os.getenv("OVIS_WEIGHTS", "official_weights/ovis_image") +TEXT_ENCODER_PATH = os.path.join(LOCAL_WEIGHTS, "text_encoder") +TOKENIZER_PATH = os.path.join(LOCAL_WEIGHTS, "tokenizer") + + +@pytest.fixture +def qwen3_model_paths(): + return TEXT_ENCODER_PATH, TOKENIZER_PATH + + +@pytest.mark.skipif( + not os.path.exists(TEXT_ENCODER_PATH), + reason=(f"Ovis-Image text_encoder not found at {TEXT_ENCODER_PATH}. " + f"Set OVIS_WEIGHTS env var or download from AIDC-AI/Ovis-Image-7B.")) +@pytest.mark.usefixtures("distributed_setup") +def test_qwen3_encoder(qwen3_model_paths): + """ + Load Qwen3 via FastVideo's TextEncoderLoader and verify its last_hidden_state + matches the HuggingFace Qwen3Model baseline (fp32, atol=1e-3). + """ + text_encoder_path, tokenizer_path = qwen3_model_paths + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + precision_str = "fp32" + precision = PRECISION_TO_TYPE[precision_str] + + hf_config = AutoConfig.from_pretrained(text_encoder_path) + logger.info(f"Qwen3 config: hidden_size={hf_config.hidden_size}, " + f"layers={hf_config.num_hidden_layers}") + + # ---- HF baseline ---- + hf_model = HFQwen3Model.from_pretrained(text_encoder_path).to( + precision).to(device).eval() + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) + + # ---- FastVideo model ---- + args = FastVideoArgs( + model_path=text_encoder_path, + pipeline_config=PipelineConfig( + text_encoder_configs=(Qwen3Config(),), + text_encoder_precisions=(precision_str,), + ), + pin_cpu_memory=False, + ) + loader = TextEncoderLoader() + fv_model = loader.load(text_encoder_path, args) + fv_model = fv_model.to(precision) + fv_model.eval() + + # ---- Weight spot-check ---- + logger.info("Spot-checking weights...") + params_hf = dict(hf_model.named_parameters()) + params_fv = dict(fv_model.named_parameters()) + + weight_names = [ + "norm.weight", + "layers.0.input_layernorm.weight", + "layers.0.post_attention_layernorm.weight", + "layers.0.mlp.down_proj.weight", + ] + for name in weight_names: + if name not in params_hf or name not in params_fv: + logger.warning(f"Weight {name} not present in both models, skipping") + continue + p_hf = params_hf[name].to(device) + p_fv = params_fv[name] + p_fv = (p_fv.to_local() if isinstance(p_fv, DTensor) else p_fv).to(p_hf) + assert p_hf.shape == p_fv.shape, \ + f"Shape mismatch for {name}: HF={p_hf.shape}, FV={p_fv.shape}" + assert_close(p_hf, p_fv, atol=1e-7, rtol=1e-7, + msg=f"Weight mismatch for {name}") + logger.info("Weight spot-check passed.") + + # ---- Forward-pass parity ---- + prompts = [ + "A vibrant sunset over the ocean with vivid colors.", + "The quick brown fox jumps over the lazy dog.", + ] + for prompt in 
prompts: + logger.info(f"Testing prompt: {prompt!r}") + tokens = tokenizer( + prompt, + return_tensors="pt", + padding="max_length", + max_length=128, + truncation=True, + ).to(device) + + with torch.no_grad(): + hf_out = hf_model( + input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + ).last_hidden_state + + with set_forward_context(current_timestep=0, attn_metadata=None): + fv_out = fv_model( + input_ids=tokens.input_ids, + attention_mask=tokens.attention_mask, + ).last_hidden_state + + assert hf_out.shape == fv_out.shape, \ + f"Output shape mismatch: HF={hf_out.shape}, FV={fv_out.shape}" + + max_diff = (hf_out - fv_out).abs().max().item() + mean_diff = (hf_out - fv_out).abs().mean().item() + logger.info(f" max_diff={max_diff:.3e} mean_diff={mean_diff:.3e}") + + atol = 1e-3 if precision_str == "fp32" else 5e-2 + assert max_diff < atol, \ + f"Output max diff {max_diff:.3e} > {atol} for prompt: {prompt!r}" + + logger.info("Qwen3 encoder parity test passed.") diff --git a/fastvideo/tests/ssim/test_ovis_image_similarity.py b/fastvideo/tests/ssim/test_ovis_image_similarity.py new file mode 100644 index 0000000000..40aa283a42 --- /dev/null +++ b/fastvideo/tests/ssim/test_ovis_image_similarity.py @@ -0,0 +1,208 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +SSIM regression test for Ovis-Image-7B text-to-image pipeline. + +Generates a 256×256 image with a fixed seed and compares it against a +committed reference image using MS-SSIM (threshold ≥ 0.98). + +How to create reference images (first time): + 1. Run this test once — it will fail with a FileNotFoundError that + includes the exact cp command needed to bless the output. + 2. Inspect the generated image under + fastvideo/tests/ssim/generated_videos/Ovis-Image-7B/TORCH_SDPA/ + 3. Copy it to the reference folder (command printed by the test). + 4. Commit the reference image and re-run — the test should now pass. + +Usage: + OVIS_WEIGHTS=official_weights/ovis_image \ + pytest fastvideo/tests/ssim/test_ovis_image_similarity.py -vs +""" + +from __future__ import annotations + +import os +import shlex +import logging + +import pytest +import torch + +logger = logging.getLogger(__name__) + +# OVIS_WEIGHTS = os.getenv("OVIS_WEIGHTS", "AIDC-AI/Ovis-Image-7B") +OVIS_WEIGHTS = os.getenv("OVIS_WEIGHTS", "AIDC-AI/Ovis-Image-7B") +MODEL_ID = "Ovis-Image-7B" + +TEST_PROMPTS = [ + 'A vibrant poster with the text "FAST VIDEO" written in bold red letters ' + "on a clean white background. 
Professional design, high contrast, 4k quality.", +] + + +def _device_reference_folder() -> str: + suffix = "_reference_videos" + device_name = torch.cuda.get_device_name(0) + if "A40" in device_name: + return "A40" + suffix + if "L40S" in device_name: + return "L40S" + suffix + if "H100" in device_name: + return "H100" + suffix + logger.warning( + "Unsupported device for ssim tests: %s; using L40S references", + device_name, + ) + return "L40S" + suffix + + +pytestmark = pytest.mark.filterwarnings( + "ignore:.*torch.jit.script_method.*:DeprecationWarning", +) + + +@pytest.mark.skipif(not torch.cuda.is_available(), + reason="Ovis-Image SSIM test requires CUDA.") +@pytest.mark.parametrize("ATTENTION_BACKEND", ["TORCH_SDPA"]) +@pytest.mark.parametrize("prompt", TEST_PROMPTS) +def test_ovis_image_similarity(prompt: str, ATTENTION_BACKEND: str) -> None: + from fastvideo import VideoGenerator + from fastvideo.tests.utils import ( + compute_video_ssim_torchvision, + write_ssim_results, + ) + + if not os.path.isdir(OVIS_WEIGHTS): + pytest.skip( + f"Ovis-Image weights not found at {OVIS_WEIGHTS} " + "(set OVIS_WEIGHTS env var to local model path or use HF hub ID)") + + old_backend = os.environ.get("FASTVIDEO_ATTENTION_BACKEND") + os.environ["FASTVIDEO_ATTENTION_BACKEND"] = ATTENTION_BACKEND + try: + script_dir = os.path.dirname(os.path.abspath(__file__)) + output_dir = os.path.join(script_dir, "generated_videos", MODEL_ID, + ATTENTION_BACKEND) + os.makedirs(output_dir, exist_ok=True) + + prompt_prefix = prompt[:100].strip() + output_video_name = f"{prompt_prefix}.mp4" + expected_video_path = os.path.join(output_dir, output_video_name) + + # Remove stale output to avoid comparing against a previous run. + for filename in os.listdir(output_dir): + if filename.endswith(".mp4") and filename.startswith(prompt_prefix): + try: + os.remove(os.path.join(output_dir, filename)) + except FileNotFoundError: + pass + + num_inference_steps = 20 + generator = VideoGenerator.from_pretrained( + model_path=OVIS_WEIGHTS, + num_gpus=1, + use_fsdp_inference=False, + dit_cpu_offload=False, + vae_cpu_offload=False, + text_encoder_cpu_offload=False, + pin_cpu_memory=False, + sp_size=1, + tp_size=1, + ) + try: + generator.generate_video( + prompt, + output_path=output_dir, + save_video=True, + height=256, + width=256, + num_frames=1, + fps=1, + num_inference_steps=num_inference_steps, + guidance_scale=5.0, + seed=42, + ) + finally: + generator.shutdown() + + # Locate the generated file. + generated_video_path = None + if os.path.exists(expected_video_path): + generated_video_path = expected_video_path + else: + candidates = [ + os.path.join(output_dir, f) for f in os.listdir(output_dir) + if f.endswith(".mp4") and f.startswith(prompt_prefix) + ] + if candidates: + generated_video_path = max(candidates, key=os.path.getmtime) + + assert generated_video_path is not None and os.path.exists( + generated_video_path), ( + f"Output video was not generated under {output_dir} " + f"for prompt '{prompt}'") + + # Locate reference. 
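+        # References are keyed by GPU model (see _device_reference_folder
+        # above) because generated output can differ slightly across devices.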
+ device_reference_folder = _device_reference_folder() + reference_folder = os.path.join(script_dir, device_reference_folder, + MODEL_ID, ATTENTION_BACKEND) + + if not os.path.exists(reference_folder): + bless_cmd = ( + f"mkdir -p {shlex.quote(reference_folder)} && " + f"cp {shlex.quote(generated_video_path)} " + f"{shlex.quote(reference_folder)}/") + pytest.fail( + f"Reference folder does not exist: {reference_folder}\n" + f"Generated image saved at: {generated_video_path}\n" + "To bless references, run:\n" + f" {bless_cmd}") + + reference_video_path = os.path.join(reference_folder, output_video_name) + if not os.path.exists(reference_video_path): + reference_video_name = None + for filename in os.listdir(reference_folder): + if filename.endswith(".mp4") and filename.startswith( + prompt_prefix): + reference_video_name = filename + break + if not reference_video_name: + bless_cmd = ( + f"cp {shlex.quote(generated_video_path)} " + f"{shlex.quote(reference_folder)}/") + pytest.fail( + f"Reference image not found for prompt '{prompt}' " + f"under: {reference_folder}\n" + f"Expected name: {output_video_name}\n" + f"Generated image saved at: {generated_video_path}\n" + f"To bless references, run:\n {bless_cmd}") + reference_video_path = os.path.join(reference_folder, + reference_video_name) + + logger.info("Computing SSIM between %s and %s", reference_video_path, + generated_video_path) + ssim_values = compute_video_ssim_torchvision(reference_video_path, + generated_video_path, + use_ms_ssim=True) + + mean_ssim = ssim_values[0] + logger.info("SSIM mean value: %s", mean_ssim) + + write_ssim_results( + output_dir, + ssim_values, + reference_video_path, + generated_video_path, + num_inference_steps, + prompt, + ) + + min_acceptable_ssim = 0.98 + assert mean_ssim >= min_acceptable_ssim, ( + f"SSIM {mean_ssim:.4f} below threshold {min_acceptable_ssim} " + f"for {MODEL_ID} with backend {ATTENTION_BACKEND}") + + finally: + if old_backend is None: + os.environ.pop("FASTVIDEO_ATTENTION_BACKEND", None) + else: + os.environ["FASTVIDEO_ATTENTION_BACKEND"] = old_backend diff --git a/fastvideo/tests/transformers/test_ovisimage.py b/fastvideo/tests/transformers/test_ovisimage.py new file mode 100644 index 0000000000..9212ed00c4 --- /dev/null +++ b/fastvideo/tests/transformers/test_ovisimage.py @@ -0,0 +1,129 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Distributed forward-pass test for OvisImageTransformer2DModel. + +Mirrors the pattern of test_hunyuanvideo.py: + - Uses FastVideo's TransformerLoader to load real weights + - Runs a forward pass with fixed inputs under a distributed environment + - Checks output shape, finiteness, and (if REFERENCE_LATENT is set) + the double-precision sum against a committed reference value + +Set OVIS_WEIGHTS env var to the local model root, e.g. + OVIS_WEIGHTS=official_weights/ovis_image \ + pytest fastvideo/tests/transformers/test_ovisimage.py -vs + +REFERENCE_LATENT is the double-precision sum of the output latent, +computed with seed=42 on bf16. Verified on RTX 5090. 
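+
+To re-pin the value on new hardware, set REFERENCE_LATENT = None below and
+copy the value the test logs on its first run.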
+""" + +import os + +import pytest +import torch + +from fastvideo.configs.models.dits import OvisImageTransformer2DModelConfig +from fastvideo.configs.pipelines.base import PipelineConfig +from fastvideo.distributed.parallel_state import (get_sp_parallel_rank, + get_sp_world_size) +from fastvideo.fastvideo_args import FastVideoArgs +from fastvideo.forward_context import set_forward_context +from fastvideo.logger import init_logger +from fastvideo.models.loader.component_loader import TransformerLoader + +logger = init_logger(__name__) + +os.environ["MASTER_ADDR"] = "localhost" +os.environ["MASTER_PORT"] = "29507" + +LOCAL_WEIGHTS = os.getenv("OVIS_WEIGHTS", "official_weights/ovis_image") +TRANSFORMER_PATH = os.path.join(LOCAL_WEIGHTS, "transformer") + +LOCAL_RANK = 0 +RANK = 0 +WORLD_SIZE = 1 + +# Reference latent: output.double().sum() with seed=42 on L40S GPU. +# Set to None to skip numerical comparison (use this on the first run to +# discover the value, then commit it here). +REFERENCE_LATENT = 292.9996643066406 # bf16, seed=42 (tolerance 1e-2) + + +@pytest.mark.skipif( + not os.path.exists(TRANSFORMER_PATH), + reason=(f"Ovis-Image transformer weights not found at {TRANSFORMER_PATH}. " + f"Set OVIS_WEIGHTS env var or download from AIDC-AI/Ovis-Image-7B.")) +@pytest.mark.usefixtures("distributed_setup") +def test_ovisimage_transformer(): + """ + Load OvisImageTransformer2DModel via TransformerLoader, run a forward pass + with fixed random inputs (seed=42), check shape + finiteness, and + optionally compare the output sum against REFERENCE_LATENT. + """ + torch.cuda.set_device(f"cuda:{LOCAL_RANK}") + torch.manual_seed(42) + + sp_rank = get_sp_parallel_rank() + sp_world_size = get_sp_world_size() + logger.info(f"rank={RANK}, sp_rank={sp_rank}, sp_world_size={sp_world_size}") + + device = torch.device(f"cuda:{LOCAL_RANK}") + precision_str = "bf16" + + args = FastVideoArgs( + model_path=TRANSFORMER_PATH, + dit_cpu_offload=False, + pipeline_config=PipelineConfig( + dit_config=OvisImageTransformer2DModelConfig(), + dit_precision=precision_str, + ), + ) + args.device = device + + loader = TransformerLoader() + model = loader.load(TRANSFORMER_PATH, args) + model.eval() + + # Fixed small inputs (32×32 latents for speed; real inference uses 128×128) + B = 1 + C_vae = 16 # in_channels=64, packed factor=4 → VAE channels = 64//4 = 16 + H, W = 32, 32 + txt_seq = 32 + joint_dim = 2048 # joint_attention_dim + + hidden_states = torch.randn(B, C_vae, H, W, + device=device, dtype=torch.bfloat16) + encoder_hidden_states = torch.randn(B, txt_seq, joint_dim, + device=device, dtype=torch.bfloat16) + # Ovis-Image FastVideo convention: timestep in [0, 1000] + timestep = torch.tensor([500.0], device=device, dtype=torch.bfloat16) + + with torch.no_grad(): + with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16): + with set_forward_context(current_timestep=0, + attn_metadata=None, + forward_batch=None): + output = model( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + ) + + # Shape and finiteness + assert output.shape == (B, C_vae, H, W), \ + f"Unexpected output shape: {output.shape}" + assert torch.isfinite(output).all(), \ + "Output contains NaN or Inf — check weight loading or forward pass" + + latent = output.double().sum().item() + logger.info(f"Output latent sum = {latent:.8f}") + + if REFERENCE_LATENT is not None: + diff = abs(REFERENCE_LATENT - latent) + logger.info(f"Reference={REFERENCE_LATENT:.8f}, diff={diff:.2e}") + assert diff < 1e-2, \ 
+ f"Latent sum differs from reference by {diff:.2e} (threshold 1e-2)" + else: + logger.info( + "REFERENCE_LATENT is None — skipping numerical comparison.\n" + "To pin it, set in this file:\n" + f" REFERENCE_LATENT = {latent}") diff --git a/fastvideo/training/__init__.py b/fastvideo/training/__init__.py index fd9d64dbfb..efacb0f08b 100644 --- a/fastvideo/training/__init__.py +++ b/fastvideo/training/__init__.py @@ -1,11 +1,11 @@ from .distillation_pipeline import DistillationPipeline +from .ovis_image_training_pipeline import OvisImageTrainingPipeline from .training_pipeline import TrainingPipeline from .wan_training_pipeline import WanTrainingPipeline -from .ltx2_training_pipeline import LTX2TrainingPipeline __all__ = [ "TrainingPipeline", "WanTrainingPipeline", - "LTX2TrainingPipeline", "DistillationPipeline", + "OvisImageTrainingPipeline", ] diff --git a/fastvideo/training/ovis_image_training_pipeline.py b/fastvideo/training/ovis_image_training_pipeline.py new file mode 100644 index 0000000000..e3c70d44b8 --- /dev/null +++ b/fastvideo/training/ovis_image_training_pipeline.py @@ -0,0 +1,109 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Training pipeline for Ovis-Image-7B text-to-image model. + +Supports: + - Full fine-tuning of the transformer (all parameters) + - LoRA fine-tuning (set lora_training=True in TrainingArgs) + - FSDP sharding (transformer_blocks and single_transformer_blocks are sharded) + - Validation generation using OvisImagePipeline + +Usage: + python -m fastvideo.training.ovis_image_training_pipeline \ + --pretrained-model-name-or-path official_weights/ovis_image \ + --data-path dataset.parquet \ + --train-batch-size 1 \ + --max-train-steps 1000 \ + --learning-rate 1e-5 +""" + +from copy import deepcopy + +from fastvideo.fastvideo_args import FastVideoArgs, TrainingArgs +from fastvideo.logger import init_logger +from fastvideo.models.schedulers.scheduling_flow_match_euler_discrete import ( + FlowMatchEulerDiscreteScheduler) +from fastvideo.pipelines.basic.ovis_image.ovis_image_pipeline import ( + OvisImagePipeline) +from fastvideo.training.training_pipeline import TrainingPipeline + +logger = init_logger(__name__) + + +class OvisImageTrainingPipeline(TrainingPipeline): + """ + Training pipeline for Ovis-Image text-to-image diffusion model. + + Inherits the full training loop from TrainingPipeline (flow-matching MSE + loss, gradient accumulation, LR scheduling, FSDP, checkpointing, etc.) + and adds Ovis-Image-specific initialisation: + + - FlowMatchEulerDiscreteScheduler with flow_shift=3.0 + - Validation using OvisImagePipeline (shared transformer weights) + + Required config modules: scheduler, transformer, text_encoder, tokenizer, vae + """ + + _required_config_modules = [ + "scheduler", + "transformer", + "text_encoder", + "tokenizer", + "vae", + ] + + def initialize_pipeline(self, fastvideo_args: FastVideoArgs) -> None: + """Set up the FlowMatchEuler scheduler with Ovis-Image defaults.""" + pipeline_cfg = fastvideo_args.pipeline_config + flow_shift = getattr(pipeline_cfg, "flow_shift", 3.0) + + self.modules["scheduler"] = FlowMatchEulerDiscreteScheduler( + num_train_timesteps=1000, + shift=flow_shift, + use_dynamic_shifting=False, + ) + logger.info( + "OvisImageTrainingPipeline: scheduler initialised " + "(flow_shift=%s)", flow_shift) + + def initialize_validation_pipeline(self, + training_args: TrainingArgs) -> None: + """ + Build a validation OvisImagePipeline that shares the trained transformer + so validation images reflect the current training state. 
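+
+        The validation pipeline is created with dit_cpu_offload=True, so the
+        shared transformer can be offloaded to CPU between validation runs.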
+ """ + logger.info("Initialising validation pipeline...") + args_copy = deepcopy(training_args) + args_copy.inference_mode = True + + self.validation_pipeline = OvisImagePipeline.from_pretrained( + training_args.model_path, + args=args_copy, + inference_mode=True, + loaded_modules={"transformer": self.get_module("transformer")}, + tp_size=training_args.tp_size, + sp_size=training_args.sp_size, + num_gpus=training_args.num_gpus, + pin_cpu_memory=training_args.pin_cpu_memory, + dit_cpu_offload=True, + ) + + +def main(args: TrainingArgs) -> None: + logger.info("Starting Ovis-Image training pipeline...") + pipeline = OvisImageTrainingPipeline.from_pretrained( + args.pretrained_model_name_or_path, args=args) + pipeline.train() + logger.info("Ovis-Image training pipeline finished.") + + +if __name__ == "__main__": + from fastvideo.fastvideo_args import TrainingArgs + from fastvideo.utils import FlexibleArgumentParser + + parser = FlexibleArgumentParser() + parser = TrainingArgs.add_cli_args(parser) + parser = FastVideoArgs.add_cli_args(parser) + args = parser.parse_args() + args.dit_cpu_offload = False + main(args) diff --git a/tests/local_tests/pipelines/test_ovis_image_pipeline_smoke.py b/tests/local_tests/pipelines/test_ovis_image_pipeline_smoke.py new file mode 100644 index 0000000000..d3f238fec1 --- /dev/null +++ b/tests/local_tests/pipelines/test_ovis_image_pipeline_smoke.py @@ -0,0 +1,94 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +End-to-end pipeline smoke test for Ovis-Image-7B. + +Runs a full generate_video() call through VideoGenerator and verifies: + - Output tensor has the expected shape + - Output is finite (no NaN / Inf) + - Output file is written to disk when save_video=True + +No reference-pipeline comparison is needed here because numerical parity +with the Diffusers implementation is already covered at the transformer +level by tests/local_tests/ovis_image/test_ovis_transformer_parity.py. + +Usage: + # With local weights (fastest) + OVIS_WEIGHTS=official_weights/ovis_image \ + pytest tests/local_tests/pipelines/test_ovis_image_pipeline_smoke.py -vs + + # With HuggingFace Hub weights + pytest tests/local_tests/pipelines/test_ovis_image_pipeline_smoke.py -vs +""" + +import os +import tempfile +from pathlib import Path + +import pytest +import torch + +os.environ.setdefault("MASTER_ADDR", "localhost") +os.environ.setdefault("MASTER_PORT", "29521") + +OVIS_WEIGHTS = os.getenv("OVIS_WEIGHTS", "AIDC-AI/Ovis-Image-7B") + + +@pytest.mark.skipif(not torch.cuda.is_available(), + reason="Ovis-Image pipeline smoke test requires CUDA.") +def test_ovis_image_pipeline_smoke(): + """Smoke test: load Ovis-Image-7B and run a single forward pass.""" + from fastvideo import VideoGenerator + + # Use a small resolution and few steps so the test runs quickly. + prompt = ( + 'A vibrant poster with the text "FAST VIDEO" written in bold red ' + "letters on a clean white background. 
+    height = 128
+    width = 128
+    num_frames = 1
+    num_inference_steps = 4
+    seed = 42
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        generator = VideoGenerator.from_pretrained(
+            OVIS_WEIGHTS,
+            num_gpus=1,
+            use_fsdp_inference=False,
+            dit_cpu_offload=False,
+            vae_cpu_offload=False,
+            text_encoder_cpu_offload=False,
+            pin_cpu_memory=False,
+        )
+        try:
+            result = generator.generate_video(
+                prompt,
+                output_path=tmpdir,
+                save_video=True,
+                height=height,
+                width=width,
+                num_frames=num_frames,
+                num_inference_steps=num_inference_steps,
+                guidance_scale=5.0,
+                seed=seed,
+                fps=1,
+            )
+        finally:
+            generator.shutdown()
+
+        # --- shape check ---
+        samples = result["samples"]
+        # Expected: (B, C, T, H, W) = (1, 3, 1, 128, 128)
+        assert samples.ndim == 5, f"Expected 5-D tensor, got shape {samples.shape}"
+        assert samples.shape[0] == 1
+        assert samples.shape[1] == 3, \
+            f"Expected 3 RGB channels, got {samples.shape[1]}"
+        assert samples.shape[2] == num_frames
+        assert samples.shape[3] == height
+        assert samples.shape[4] == width
+
+        # --- finite check ---
+        assert torch.isfinite(samples).all(), "Output contains NaN or Inf values"
+
+        # --- output file check ---
+        output_files = list(Path(tmpdir).glob("*.mp4")) + list(
+            Path(tmpdir).glob("*.png"))
+        assert len(output_files) > 0, (
+            f"No output file was saved under {tmpdir}. Files: {os.listdir(tmpdir)}")
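+
+        # --- value-range log (informational) ---
+        # The exact output normalisation ([0, 1] vs [-1, 1]) is a detail of
+        # the pipeline's postprocessing that this smoke test deliberately
+        # does not pin down, so log the observed range for debugging rather
+        # than asserting on it.
+        print(f"Sample dtype={samples.dtype}, value range="
+              f"[{samples.min().item():.4f}, {samples.max().item():.4f}]")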