Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion lightx2v/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@
__author__ = "LightX2V Contributors"
__license__ = "Apache 2.0"

import os

import lightx2v_platform.set_ai_device
from lightx2v import common, models, utils
from lightx2v.pipeline import LightX2VPipeline

if os.getenv("LIGHTX2V_SKIP_PIPELINE_IMPORT", "0").lower() in ("1", "true", "yes", "on"):
LightX2VPipeline = None
else:
from lightx2v.pipeline import LightX2VPipeline

__all__ = [
"__version__",
Expand Down
6 changes: 4 additions & 2 deletions lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def load_text_encoder(
config = AutoConfig.from_pretrained(text_encoder_path)
with init_empty_weights():
text_encoder = AutoModel.from_config(config)
text_encoder = text_encoder.language_model
if hasattr(text_encoder, "language_model"):
text_encoder = text_encoder.language_model

if text_encoder_quant_scheme in ["int8", "int8-vllm"]:
linear_cls = VllmQuantLinearInt8
Expand Down Expand Up @@ -157,7 +158,8 @@ def load_text_encoder(

else:
text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True)
text_encoder = text_encoder.language_model
if hasattr(text_encoder, "language_model"):
text_encoder = text_encoder.language_model

text_encoder.final_layer_norm = text_encoder.norm

Expand Down
28 changes: 18 additions & 10 deletions lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,20 @@
sageattn3_blackwell = None


def unpad_input_compat(*args, **kwargs):
    """Call ``unpad_input`` and normalise a four-value result to five values.

    Every positional and keyword argument is forwarded untouched. When the
    underlying call yields exactly four values, a trailing ``None`` is
    appended; a result of any other length is handed back unchanged.

    NOTE(review): presumably the four-value form comes from flash-attn
    releases that do not return ``used_seqlens_in_batch`` -- confirm against
    the installed version.

    Returns:
        A five-element tuple when ``unpad_input`` returned four values,
        otherwise whatever ``unpad_input`` returned.
    """
    outputs = unpad_input(*args, **kwargs)
    if len(outputs) != 4:
        return outputs
    # Fill the fifth slot with None so five-way unpacking keeps working.
    return (*outputs, None)


def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False):
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(x, key_padding_mask)
x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input_compat(x, key_padding_mask)

x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_attn_varlen_qkvpacked_func(
Expand All @@ -72,9 +80,9 @@ def flash_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, sof
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)

query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)

query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
Expand All @@ -92,9 +100,9 @@ def sage_attn_no_pad_v2(qkv, key_padding_mask, causal=False, dropout_p=0.0, soft
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)

query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)

query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
Expand All @@ -115,9 +123,9 @@ def sage_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, soft
batch_size, seqlen, _, nheads, head_dim = qkv.shape
query, key, value = qkv.unbind(dim=2)

query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)
query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask)
key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask)
value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask)

query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads)
key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F

from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
Expand Down Expand Up @@ -142,6 +143,15 @@ def __init__(self, config):
self.previous_modulated_input_even = None
self.previous_residual_even = None

def _relative_l1_distance(self, current, previous):
    """Return ``sum(|current - previous|) / sum(|previous|)`` as a Python float.

    Both sums are accumulated in float32. If the instance carries a non-None
    ``seq_p_group`` and ``torch.distributed`` is available and initialised,
    the numerator and denominator are summed across that process group before
    the division, so every rank in the group gets the same ratio.

    Args:
        current: Tensor holding the new values.
        previous: Tensor holding the reference values; must be broadcastable
            against ``current`` for the subtraction.

    Returns:
        The ratio as a plain ``float``. The denominator is floored at 1e-12
        to avoid division by zero.
    """
    numerator = (current - previous).abs().float().sum()
    denominator = previous.abs().float().sum().clamp(min=1e-12)
    # Pack both scalars into one tensor so a single collective reduces them.
    totals = torch.stack((numerator, denominator))

    group = getattr(self, "seq_p_group", None)
    can_reduce = group is not None and dist.is_available() and dist.is_initialized()
    if can_reduce:
        dist.all_reduce(totals, op=dist.ReduceOp.SUM, group=group)

    ratio = totals[0] / totals[1].clamp(min=1e-12)
    return ratio.cpu().item()
Comment thread
starrkk marked this conversation as resolved.

def calculate_should_calc(self, img, vec, block):
inp = img.clone()
vec_ = vec.clone()
Expand All @@ -167,17 +177,15 @@ def calculate_should_calc(self, img, vec, block):
else:
rescale_func = np.poly1d(self.coefficients)
if self.scheduler.infer_condition:
self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_modulated_input_odd).abs().mean() / self.previous_modulated_input_odd.abs().mean()).cpu().item())
self.accumulated_rel_l1_distance_odd += rescale_func(self._relative_l1_distance(modulated_inp, self.previous_modulated_input_odd))
if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
should_calc = False
else:
should_calc = True
self.accumulated_rel_l1_distance_odd = 0
self.previous_modulated_input_odd = modulated_inp
else:
self.accumulated_rel_l1_distance_even += rescale_func(
((modulated_inp - self.previous_modulated_input_even).abs().mean() / self.previous_modulated_input_even.abs().mean()).cpu().item()
)
self.accumulated_rel_l1_distance_even += rescale_func(self._relative_l1_distance(modulated_inp, self.previous_modulated_input_even))
if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
should_calc = False
else:
Expand Down
16 changes: 16 additions & 0 deletions lightx2v/models/runners/default_runner.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import gc
import os
import time

import numpy as np
import requests
Expand All @@ -11,6 +12,7 @@
from requests.exceptions import RequestException

from lightx2v.models.runners.base_runner import BaseRunner
from lightx2v.models.runners.vae_postprocess import env_flag, should_skip_rank_postprocess, sync_device_if_available
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
Expand Down Expand Up @@ -485,7 +487,21 @@ def post_prompt_enhancer(self):
return enhanced_prompt

def process_images_after_vae_decoder(self):
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
rank0_post_only = env_flag("LIGHTX2V_VAE_RANK0_POST_ONLY", False)
if should_skip_rank_postprocess(self.gen_video_final, rank=rank, enabled=rank0_post_only):
logger.info(f"[VAE_DETAIL] rank={rank} skip postprocess for empty non-rank0 VAE payload")
return {"video": None}

detail_timing = env_flag("LIGHTX2V_VAE_DETAIL_TIMING", False)
if detail_timing:
sync_device_if_available()
post_start = time.perf_counter()

self.gen_video_final = wan_vae_to_comfy(self.gen_video_final)
if detail_timing:
sync_device_if_available()
logger.info(f"[VAE_DETAIL] rank={rank} wan_vae_to_comfy_s={time.perf_counter() - post_start:.6f}")

if "video_frame_interpolation" in self.config:
assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None
Expand Down
15 changes: 15 additions & 0 deletions lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
import torch
import torch.distributed as dist
import torchvision.transforms as transforms
from PIL import Image
from loguru import logger
Expand All @@ -13,6 +14,7 @@
from lightx2v.models.input_encoders.hf.hunyuan15.siglip.model import SiglipVisionEncoder
from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.runners.vae_postprocess import crop_spatial_to_size, env_flag
from lightx2v.models.schedulers.hunyuan_video.feature_caching.scheduler import HunyuanVideo15SchedulerCaching
from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler, HunyuanVideo15Scheduler
from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import HunyuanVideo15VAE
Expand Down Expand Up @@ -122,6 +124,8 @@ def get_latent_shape_with_target_hw(self, origin_size=None):
width, height = origin_size
target_size = self.config["transformer_model_name"].split("_")[0]
target_height, target_width = self.get_closest_resolution_given_original_size((int(width), int(height)), target_size)
self.output_target_height = target_height
self.output_target_width = target_width
latent_shape = [
self.config.get("in_channels", 32),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
Expand Down Expand Up @@ -430,6 +434,17 @@ def run_vae_decoder(self, latents):
if self.sr_version:
latents = self.run_sr(latents)
images = super().run_vae_decoder(latents)
if env_flag("LIGHTX2V_CROP_PADDED_VAE_OUTPUT", False):
before_shape = tuple(images.shape) if isinstance(images, torch.Tensor) else None
images = crop_spatial_to_size(
images,
getattr(self, "output_target_height", None),
getattr(self, "output_target_width", None),
)
after_shape = tuple(images.shape) if isinstance(images, torch.Tensor) else None
if before_shape != after_shape:
rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
logger.info(f"[VAE_DETAIL] rank={rank} crop_padded_vae_output {before_shape} -> {after_shape}")
return images

@ProfilingContext4DebugL2("Run Encoders")
Expand Down
74 changes: 74 additions & 0 deletions lightx2v/models/runners/vae_postprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os

import torch

from lightx2v_platform.base.global_var import AI_DEVICE


def env_flag(name, default=False):
    """Read environment variable ``name`` as a boolean switch.

    Args:
        name: Name of the environment variable to look up.
        default: Value returned as is when the variable is not set.

    Returns:
        ``default`` when the variable is unset. Otherwise ``True`` if the
        value, after stripping whitespace and lower-casing, is one of
        ``"1"``, ``"true"``, ``"yes"`` or ``"on"``, and ``False`` for any
        other value (including the empty string).
    """
    raw = os.environ.get(name)
    return default if raw is None else (raw.strip().lower() in {"1", "true", "yes", "on"})


def has_no_video_payload(video):
    """Tell whether ``video`` carries no data.

    Returns:
        ``True`` when ``video`` is ``None`` or a ``torch.Tensor`` with zero
        elements; ``False`` for every other value, including non-tensor
        containers even if they are empty.
    """
    if isinstance(video, torch.Tensor):
        return video.numel() == 0
    return video is None


def should_skip_rank_postprocess(video, rank, enabled):
    """Decide whether this rank can skip post-processing of ``video``.

    Skipping requires all three conditions: the feature is ``enabled``, the
    caller is not rank 0, and ``has_no_video_payload(video)`` reports that
    there is nothing to process. The conditions are checked in that order,
    and the first falsy one is returned as is.

    Args:
        video: Candidate payload, passed to ``has_no_video_payload``.
        rank: Rank of the calling process.
        enabled: Switch for the skip behaviour.
    """
    if not enabled:
        return enabled
    is_worker = rank != 0
    if not is_worker:
        return is_worker
    return has_no_video_payload(video)


def crop_spatial_to_size(video, target_height=None, target_width=None):
    """Crop the height/width axes of ``video`` down to the requested size.

    The crop keeps the leading ``target_height`` rows and ``target_width``
    columns and never enlarges: an axis that is already at or below its
    target is left alone. Which axes count as height and width is decided by
    ``_spatial_dims``.

    Args:
        video: Value to crop. Anything that is not a ``torch.Tensor`` is
            returned unchanged.
        target_height: Maximum height to keep, or ``None`` to leave the
            height axis untouched. Converted with ``int()``.
        target_width: Maximum width to keep, or ``None`` to leave the width
            axis untouched. Converted with ``int()``.

    Returns:
        ``video`` itself when no cropping is needed or possible (non-tensor,
        empty tensor, fewer than two axes, no targets given, or already small
        enough); otherwise a contiguous cropped copy.
    """
    if not isinstance(video, torch.Tensor):
        return video
    if video.numel() == 0:
        return video
    if target_height is None and target_width is None:
        return video
    # A 0-D or 1-D tensor has no height/width pair; indexing those axes would
    # raise IndexError, so pass it through like the other no-op cases.
    if video.ndim < 2:
        return video

    height_dim, width_dim = _spatial_dims(video)
    height = video.shape[height_dim]
    width = video.shape[width_dim]
    crop_h = min(height, int(target_height)) if target_height is not None else height
    crop_w = min(width, int(target_width)) if target_width is not None else width
    if crop_h == height and crop_w == width:
        return video

    # Slice only the two spatial axes; every other axis is kept whole.
    slices = [slice(None)] * video.ndim
    slices[height_dim] = slice(0, crop_h)
    slices[width_dim] = slice(0, crop_w)
    return video[tuple(slices)].contiguous()


def sync_device_if_available():
    """Call ``torch.<AI_DEVICE>.synchronize()`` when that function exists.

    Looks up the attribute of ``torch`` named by ``AI_DEVICE`` and, if it is
    present and exposes a ``synchronize`` attribute, calls it with no
    arguments. Does nothing when either lookup comes back empty.
    """
    backend = getattr(torch, AI_DEVICE, None)
    sync_fn = None if backend is None else getattr(backend, "synchronize", None)
    if sync_fn is not None:
        sync_fn()


def _spatial_dims(video):
    """Guess which axes of ``video`` are height and width.

    The guess is a heuristic on axis sizes, applied only to 4-D and 5-D
    inputs:

    * last axis of size 1, 3 or 4 -> treated as channels-last
      (B,H,W,C or B,T,H,W,C), so height/width are the two axes before it;
    * otherwise axis 1 of size 1, 3, 4, 16 or 32 -> treated as channels-first
      (B,C,H,W or B,C,T,H,W), so height/width are the last two axes.

    Anything else, including other ranks, falls back to the last two axes
    expressed as negative indices.

    Returns:
        ``(height_dim, width_dim)`` as a tuple of ints.
    """
    rank = video.ndim
    if rank in (4, 5):
        shape = video.shape
        if shape[-1] in (1, 3, 4):
            # Channels-last layout: H and W sit right before the final axis.
            return rank - 3, rank - 2
        if shape[1] in (1, 3, 4, 16, 32):
            # Channels-first layout: H and W are the trailing two axes.
            return rank - 2, rank - 1
    return -2, -1
Loading