From 6031a6fcaa7d9815fdb96027fb0e808e2845124c Mon Sep 17 00:00:00 2001 From: zhenggf Date: Mon, 22 Jun 2026 16:11:39 +0800 Subject: [PATCH 1/5] fix: enable Hygon 8-card Hunyuan 1.5 i2v inference (cherry picked from commit d60b8f32c7787054faba8fbacaf5c38fac3ffbfb) --- lightx2v/__init__.py | 8 ++- .../hf/hunyuan15/qwen25/model.py | 6 +- .../hunyuan_video/infer/attn_no_pad.py | 28 +++++--- .../feature_caching/transformer_infer.py | 15 +++-- lightx2v/models/runners/default_runner.py | 16 +++++ .../hunyuan_video/hunyuan_video_15_runner.py | 15 +++++ lightx2v/models/runners/vae_postprocess.py | 66 +++++++++++++++++++ .../hf/hunyuanvideo15/hunyuanvideo_15_vae.py | 26 ++++++++ .../ops/attn/hygon_dcu/flash_attn.py | 31 ++++++--- 9 files changed, 184 insertions(+), 27 deletions(-) create mode 100644 lightx2v/models/runners/vae_postprocess.py diff --git a/lightx2v/__init__.py b/lightx2v/__init__.py index a2250870d..538e2ab19 100755 --- a/lightx2v/__init__.py +++ b/lightx2v/__init__.py @@ -2,9 +2,15 @@ __author__ = "LightX2V Contributors" __license__ = "Apache 2.0" +import os + import lightx2v_platform.set_ai_device from lightx2v import common, models, utils -from lightx2v.pipeline import LightX2VPipeline + +if os.getenv("LIGHTX2V_SKIP_PIPELINE_IMPORT", "0").lower() in ("1", "true", "yes", "on"): + LightX2VPipeline = None +else: + from lightx2v.pipeline import LightX2VPipeline __all__ = [ "__version__", diff --git a/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py b/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py index d7295219f..7a0357e83 100755 --- a/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py +++ b/lightx2v/models/input_encoders/hf/hunyuan15/qwen25/model.py @@ -124,7 +124,8 @@ def load_text_encoder( config = AutoConfig.from_pretrained(text_encoder_path) with init_empty_weights(): text_encoder = AutoModel.from_config(config) - text_encoder = text_encoder.language_model + if hasattr(text_encoder, "language_model"): + text_encoder = text_encoder.language_model if text_encoder_quant_scheme in ["int8", "int8-vllm"]: linear_cls = VllmQuantLinearInt8 @@ -157,7 +158,8 @@ def load_text_encoder( else: text_encoder = AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True) - text_encoder = text_encoder.language_model + if hasattr(text_encoder, "language_model"): + text_encoder = text_encoder.language_model text_encoder.final_layer_norm = text_encoder.norm diff --git a/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py b/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py index c5409f66a..af3d93a03 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py +++ b/lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py @@ -40,12 +40,20 @@ sageattn3_blackwell = None +def unpad_input_compat(*args, **kwargs): + result = unpad_input(*args, **kwargs) + if len(result) == 4: + x_unpad, indices, cu_seqlens, max_s = result + return x_unpad, indices, cu_seqlens, max_s, None + return result + + def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None, deterministic=False): batch_size = qkv.shape[0] seqlen = qkv.shape[1] nheads = qkv.shape[-2] x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input(x, key_padding_mask) + x_unpad, indices, cu_seqlens, max_s, used_seqlens_in_batch = unpad_input_compat(x, key_padding_mask) x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) output_unpad = flash_attn_varlen_qkvpacked_func( @@ -72,9 +80,9 @@ def flash_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, sof batch_size, seqlen, _, nheads, head_dim = qkv.shape query, key, value = qkv.unbind(dim=2) - query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) - key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) - value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) @@ -92,9 +100,9 @@ def sage_attn_no_pad_v2(qkv, key_padding_mask, causal=False, dropout_p=0.0, soft batch_size, seqlen, _, nheads, head_dim = qkv.shape query, key, value = qkv.unbind(dim=2) - query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) - key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) - value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) @@ -115,9 +123,9 @@ def sage_attn_no_pad_v3(qkv, key_padding_mask, causal=False, dropout_p=0.0, soft batch_size, seqlen, _, nheads, head_dim = qkv.shape query, key, value = qkv.unbind(dim=2) - query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) - key_unpad, _, cu_seqlens_k, _, _ = unpad_input(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) - value_unpad, _, _, _, _ = unpad_input(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) + query_unpad, indices, cu_seqlens_q, max_seqlen_q, _ = unpad_input_compat(rearrange(query, "b s h d -> b s (h d)"), key_padding_mask) + key_unpad, _, cu_seqlens_k, _, _ = unpad_input_compat(rearrange(key, "b s h d -> b s (h d)"), key_padding_mask) + value_unpad, _, _, _, _ = unpad_input_compat(rearrange(value, "b s h d -> b s (h d)"), key_padding_mask) query_unpad = rearrange(query_unpad, "nnz (h d) -> nnz h d", h=nheads) key_unpad = rearrange(key_unpad, "nnz (h d) -> nnz h d", h=nheads) diff --git a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py index 56041ec8b..a72f8ed1b 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.nn.functional as F +import torch.distributed as dist from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer from lightx2v_platform.base.global_var import AI_DEVICE @@ -142,6 +143,14 @@ def __init__(self, config): self.previous_modulated_input_even = None self.previous_residual_even = None + def _relative_l1_distance(self, current, previous): + diff_sum = (current - previous).abs().float().sum() + prev_sum = previous.abs().float().sum().clamp_min(1e-12) + stats = torch.stack([diff_sum, prev_sum]) + if self.seq_p_group is not None and dist.is_available() and dist.is_initialized(): + dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=self.seq_p_group) + return (stats[0] / stats[1].clamp_min(1e-12)).cpu().item() + def calculate_should_calc(self, img, vec, block): inp = img.clone() vec_ = vec.clone() @@ -167,7 +176,7 @@ def calculate_should_calc(self, img, vec, block): else: rescale_func = np.poly1d(self.coefficients) if self.scheduler.infer_condition: - self.accumulated_rel_l1_distance_odd += rescale_func(((modulated_inp - self.previous_modulated_input_odd).abs().mean() / self.previous_modulated_input_odd.abs().mean()).cpu().item()) + self.accumulated_rel_l1_distance_odd += rescale_func(self._relative_l1_distance(modulated_inp, self.previous_modulated_input_odd)) if self.accumulated_rel_l1_distance_odd < self.teacache_thresh: should_calc = False else: @@ -175,9 +184,7 @@ def calculate_should_calc(self, img, vec, block): self.accumulated_rel_l1_distance_odd = 0 self.previous_modulated_input_odd = modulated_inp else: - self.accumulated_rel_l1_distance_even += rescale_func( - ((modulated_inp - self.previous_modulated_input_even).abs().mean() / self.previous_modulated_input_even.abs().mean()).cpu().item() - ) + self.accumulated_rel_l1_distance_even += rescale_func(self._relative_l1_distance(modulated_inp, self.previous_modulated_input_even)) if self.accumulated_rel_l1_distance_even < self.teacache_thresh: should_calc = False else: diff --git a/lightx2v/models/runners/default_runner.py b/lightx2v/models/runners/default_runner.py index 3e21e4a22..b6ccfa395 100755 --- a/lightx2v/models/runners/default_runner.py +++ b/lightx2v/models/runners/default_runner.py @@ -1,5 +1,6 @@ import gc import os +import time import numpy as np import requests @@ -11,6 +12,7 @@ from requests.exceptions import RequestException from lightx2v.models.runners.base_runner import BaseRunner +from lightx2v.models.runners.vae_postprocess import env_flag, should_skip_rank_postprocess, sync_device_if_available from lightx2v.server.metrics import monitor_cli from lightx2v.utils.envs import * from lightx2v.utils.generate_task_id import generate_task_id @@ -485,7 +487,21 @@ def post_prompt_enhancer(self): return enhanced_prompt def process_images_after_vae_decoder(self): + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + rank0_post_only = env_flag("LIGHTX2V_VAE_RANK0_POST_ONLY", False) + if should_skip_rank_postprocess(self.gen_video_final, rank=rank, enabled=rank0_post_only): + logger.info(f"[VAE_DETAIL] rank={rank} skip postprocess for empty non-rank0 VAE payload") + return {"video": None} + + detail_timing = env_flag("LIGHTX2V_VAE_DETAIL_TIMING", False) + if detail_timing: + sync_device_if_available() + post_start = time.perf_counter() + self.gen_video_final = wan_vae_to_comfy(self.gen_video_final) + if detail_timing: + sync_device_if_available() + logger.info(f"[VAE_DETAIL] rank={rank} wan_vae_to_comfy_s={time.perf_counter() - post_start:.6f}") if "video_frame_interpolation" in self.config: assert self.vfi_model is not None and self.config["video_frame_interpolation"].get("target_fps", None) is not None diff --git a/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py index 560ab895b..8ace35aaf 100755 --- a/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py +++ b/lightx2v/models/runners/hunyuan_video/hunyuan_video_15_runner.py @@ -4,6 +4,7 @@ import numpy as np import torch +import torch.distributed as dist import torchvision.transforms as transforms from PIL import Image from loguru import logger @@ -13,6 +14,7 @@ from lightx2v.models.input_encoders.hf.hunyuan15.siglip.model import SiglipVisionEncoder from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model from lightx2v.models.runners.default_runner import DefaultRunner +from lightx2v.models.runners.vae_postprocess import crop_spatial_to_size, env_flag from lightx2v.models.schedulers.hunyuan_video.feature_caching.scheduler import HunyuanVideo15SchedulerCaching from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler, HunyuanVideo15Scheduler from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import HunyuanVideo15VAE @@ -122,6 +124,8 @@ def get_latent_shape_with_target_hw(self, origin_size=None): width, height = origin_size target_size = self.config["transformer_model_name"].split("_")[0] target_height, target_width = self.get_closest_resolution_given_original_size((int(width), int(height)), target_size) + self.output_target_height = target_height + self.output_target_width = target_width latent_shape = [ self.config.get("in_channels", 32), (self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1, @@ -430,6 +434,17 @@ def run_vae_decoder(self, latents): if self.sr_version: latents = self.run_sr(latents) images = super().run_vae_decoder(latents) + if env_flag("LIGHTX2V_CROP_PADDED_VAE_OUTPUT", False): + before_shape = tuple(images.shape) if isinstance(images, torch.Tensor) else None + images = crop_spatial_to_size( + images, + getattr(self, "output_target_height", None), + getattr(self, "output_target_width", None), + ) + after_shape = tuple(images.shape) if isinstance(images, torch.Tensor) else None + if before_shape != after_shape: + rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0 + logger.info(f"[VAE_DETAIL] rank={rank} crop_padded_vae_output {before_shape} -> {after_shape}") return images @ProfilingContext4DebugL2("Run Encoders") diff --git a/lightx2v/models/runners/vae_postprocess.py b/lightx2v/models/runners/vae_postprocess.py new file mode 100644 index 000000000..7f6889dc2 --- /dev/null +++ b/lightx2v/models/runners/vae_postprocess.py @@ -0,0 +1,66 @@ +import os + +import torch + +from lightx2v_platform.base.global_var import AI_DEVICE + + +def env_flag(name, default=False): + value = os.getenv(name) + if value is None: + return default + return value.strip().lower() in ("1", "true", "yes", "on") + + +def has_no_video_payload(video): + if video is None: + return True + if isinstance(video, torch.Tensor): + return video.numel() == 0 + return False + + +def should_skip_rank_postprocess(video, rank, enabled): + return enabled and rank != 0 and has_no_video_payload(video) + + +def crop_spatial_to_size(video, target_height=None, target_width=None): + if not isinstance(video, torch.Tensor): + return video + if video.numel() == 0: + return video + if target_height is None and target_width is None: + return video + + height_dim, width_dim = _spatial_dims(video) + height = video.shape[height_dim] + width = video.shape[width_dim] + crop_h = min(height, int(target_height)) if target_height is not None else height + crop_w = min(width, int(target_width)) if target_width is not None else width + if crop_h == height and crop_w == width: + return video + + slices = [slice(None)] * video.ndim + slices[height_dim] = slice(0, crop_h) + slices[width_dim] = slice(0, crop_w) + return video[tuple(slices)].contiguous() + + +def sync_device_if_available(): + device_module = getattr(torch, AI_DEVICE, None) + if device_module is None: + return + synchronize = getattr(device_module, "synchronize", None) + if synchronize is not None: + synchronize() + + +def _spatial_dims(video): + if video.ndim != 5: + return -2, -1 + # VAE tensors are usually B,C,T,H,W before wan_vae_to_comfy and + # B,T,H,W,C afterwards. In both layouts, H/W are the two dims before + # channels only for the postprocessed form; rank0 crop is done before it. + if video.shape[1] in (1, 3, 4, 16, 32): + return 3, 4 + return 2, 3 diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py index 52f013e69..c6bb7dc01 100755 --- a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py +++ b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py @@ -1,9 +1,11 @@ import math import os +import time from dataclasses import dataclass from typing import Optional, Tuple, Union import numpy as np +from loguru import logger import torch import torch.distributed as dist import torch.nn.functional as F @@ -15,6 +17,7 @@ from torch import Tensor, nn from lightx2v_platform.base.global_var import AI_DEVICE +from lightx2v.models.runners.vae_postprocess import env_flag torch_device_module = getattr(torch, AI_DEVICE) @@ -813,6 +816,8 @@ def decode(self, z): @torch.no_grad() def decode_dist_2d(self, z, world_size_h, world_size_w): + detail_timing = env_flag("LIGHTX2V_VAE_DETAIL_TIMING", False) + rank0_only = env_flag("LIGHTX2V_VAE_DIST_RANK0_ONLY", False) cur_rank = dist.get_rank() cur_rank_h = cur_rank // world_size_w cur_rank_w = cur_rank % world_size_w @@ -850,7 +855,13 @@ def decode_dist_2d(self, z, world_size_h, world_size_w): zs_chunk = z[:, :, :, h_start:h_end, w_start:w_end].contiguous() # Decode the chunk + if detail_timing: + self.device_synchronize() + decode_start = time.perf_counter() images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0] + if detail_timing: + self.device_synchronize() + logger.info(f"[VAE_DETAIL] rank={cur_rank} decode_chunk_s={time.perf_counter() - decode_start:.6f} latent_chunk_shape={tuple(zs_chunk.shape)} decoded_chunk_shape={tuple(images_chunk.shape)}") # Remove padding from decoded chunk spatial_ratio = 16 @@ -880,11 +891,23 @@ def decode_dist_2d(self, z, world_size_h, world_size_w): total_processes = world_size_h * world_size_w full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)] + if detail_timing: + self.device_synchronize() + gather_start = time.perf_counter() dist.all_gather(full_images, images_chunk) self.device_synchronize() + if detail_timing: + logger.info(f"[VAE_DETAIL] rank={cur_rank} all_gather_s={time.perf_counter() - gather_start:.6f}") + + if rank0_only and cur_rank != 0: + if detail_timing: + logger.info(f"[VAE_DETAIL] rank={cur_rank} skip reconstruct after all_gather") + return torch.empty(0, device=z.device, dtype=images_chunk.dtype) # Reconstruct the full image tensor + if detail_timing: + reconstruct_start = time.perf_counter() image_rows = [] for h_idx in range(world_size_h): image_cols = [] @@ -894,6 +917,9 @@ def decode_dist_2d(self, z, world_size_h, world_size_w): image_rows.append(torch.cat(image_cols, dim=4)) images = torch.cat(image_rows, dim=3) + if detail_timing: + self.device_synchronize() + logger.info(f"[VAE_DETAIL] rank={cur_rank} reconstruct_s={time.perf_counter() - reconstruct_start:.6f} full_shape={tuple(images.shape)}") return images diff --git a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py index d1fff4fe2..0dcb272c6 100644 --- a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py +++ b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py @@ -112,6 +112,11 @@ def half(x): # Compute softmax scale if not provided if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(q.shape[-1]) + + if cu_seqlens_q is not None and cu_seqlens_q.is_cpu: + cu_seqlens_q = cu_seqlens_q.to(q_flat.device, non_blocking=True) + if cu_seqlens_kv is not None and cu_seqlens_kv.is_cpu: + cu_seqlens_kv = cu_seqlens_kv.to(k_flat.device, non_blocking=True) # Use Flash Attention 2.6.1 (ROCm version) with varlen interface if SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and q.shape[1] == k.shape[1]: topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.5")) @@ -147,7 +152,7 @@ def half(x): output = output.reshape(bs * max_seqlen_q, -1) return output.to(out_dtype) - def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, causal=False, dropout_p=0.0): + def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, max_seqlen_kv=None, causal=False, dropout_p=0.0): """ Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available. @@ -162,17 +167,23 @@ def _sdpa_fallback(self, q, k, v, cu_seqlens_q, max_seqlen_q, causal=False, drop Returns: Output tensor: [B*Lq, C] (flattened batch) """ - # Reshape from flattened format to batched format - bs = cu_seqlens_q.shape[0] - 1 - # Reshape q, k, v to [B, L, Nq, C] - q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1]) - k = k.reshape(bs, max_seqlen_q, k.shape[-2], k.shape[-1]) - v = v.reshape(bs, max_seqlen_q, v.shape[-2], v.shape[-1]) + max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv + + if q.dim() == 4: + bs = q.shape[0] + q_batched = q + k_batched = k + v_batched = v + else: + bs = cu_seqlens_q.shape[0] - 1 + q_batched = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1]) + k_batched = k.reshape(bs, max_seqlen_kv, k.shape[-2], k.shape[-1]) + v_batched = v.reshape(bs, max_seqlen_kv, v.shape[-2], v.shape[-1]) # Transpose to [B, Nq, L, C] for SDPA - q = q.transpose(1, 2) - k = k.transpose(1, 2) - v = v.transpose(1, 2) + q = q_batched.transpose(1, 2) + k = k_batched.transpose(1, 2) + v = v_batched.transpose(1, 2) out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p) From 99a08bb3ac094d59edf66a61910cc46c2673e63f Mon Sep 17 00:00:00 2001 From: zhenggf Date: Mon, 22 Jun 2026 17:20:16 +0800 Subject: [PATCH 2/5] fix: honor SLA topk environment setting (cherry picked from commit e8ee93a79bd20dce2d084e992a8e140710f2c9b6) --- lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py index 0dcb272c6..a9b9adcb8 100644 --- a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py +++ b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py @@ -119,7 +119,7 @@ def half(x): cu_seqlens_kv = cu_seqlens_kv.to(k_flat.device, non_blocking=True) # Use Flash Attention 2.6.1 (ROCm version) with varlen interface if SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and q.shape[1] == k.shape[1]: - topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.5")) + topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.4")) q = q_flat.unsqueeze(0) k = k_flat.unsqueeze(0) @@ -129,7 +129,7 @@ def half(x): q, k, v, - topk=0.4, + topk=topk_value, ) else: output = flash_attn_varlen_func( From fe93150bb50f12c8ffc43da8b3133649f95c0fbc Mon Sep 17 00:00:00 2001 From: zhenggf Date: Tue, 23 Jun 2026 16:46:19 +0800 Subject: [PATCH 3/5] Optimize Hunyuan VAE decode path (cherry picked from commit b066001a517b59e5ddbf8f7dcce4a14a017be46d) --- .../hf/hunyuanvideo15/hunyuanvideo_15_vae.py | 94 +++++++++++++++++-- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py index c6bb7dc01..dd34691b4 100755 --- a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py +++ b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py @@ -21,6 +21,80 @@ torch_device_module = getattr(torch, AI_DEVICE) +_VAE_CONV_SHAPE_LOGGED: set[tuple] = set() +_VAE_CONV_SHAPE_HOOKED: set[int] = set() + + +def _dist_rank() -> int: + if dist.is_available() and dist.is_initialized(): + return dist.get_rank() + return 0 + + +def _tensor_shape(value): + if isinstance(value, Tensor): + return tuple(value.shape) + if isinstance(value, (list, tuple)): + return tuple(_tensor_shape(item) for item in value) + if isinstance(value, dict): + return {key: _tensor_shape(item) for key, item in value.items()} + return type(value).__name__ + + +def _maybe_register_vae_conv_shape_hooks(module: nn.Module) -> None: + if not env_flag("LIGHTX2V_VAE_CONV_SHAPE_LOG", False): + return + + module_id = id(module) + if module_id in _VAE_CONV_SHAPE_HOOKED: + return + _VAE_CONV_SHAPE_HOOKED.add(module_id) + + rank = _dist_rank() + if rank != 0 and not env_flag("LIGHTX2V_VAE_CONV_SHAPE_LOG_ALL_RANKS", False): + return + + hook_count = 0 + + def make_hook(name): + def hook(conv_module, inputs, output): + weight_shape = tuple(conv_module.weight.shape) if getattr(conv_module, "weight", None) is not None else None + input_shape = _tensor_shape(inputs[0] if inputs else inputs) + output_shape = _tensor_shape(output) + key = ( + rank, + name, + conv_module.__class__.__name__, + weight_shape, + input_shape, + output_shape, + getattr(conv_module, "stride", None), + getattr(conv_module, "padding", None), + getattr(conv_module, "dilation", None), + getattr(conv_module, "groups", None), + ) + if key in _VAE_CONV_SHAPE_LOGGED: + return + _VAE_CONV_SHAPE_LOGGED.add(key) + logger.info( + "[VAE_CONV_SHAPE] " + f"rank={rank} module={name} type={conv_module.__class__.__name__} " + f"input_shape={input_shape} weight_shape={weight_shape} output_shape={output_shape} " + f"kernel_size={getattr(conv_module, 'kernel_size', None)} stride={getattr(conv_module, 'stride', None)} " + f"padding={getattr(conv_module, 'padding', None)} dilation={getattr(conv_module, 'dilation', None)} " + f"groups={getattr(conv_module, 'groups', None)} bias={getattr(conv_module, 'bias', None) is not None} " + f"dtype={getattr(conv_module.weight, 'dtype', None) if getattr(conv_module, 'weight', None) is not None else None}" + ) + + return hook + + for name, child in module.named_modules(): + if isinstance(child, (nn.Conv2d, nn.Conv3d)): + child.register_forward_hook(make_hook(name)) + hook_count += 1 + + logger.info(f"[VAE_CONV_SHAPE] rank={rank} registered {hook_count} conv hooks on {module.__class__.__name__}") + @dataclass class DecoderOutput(BaseOutput): @@ -805,13 +879,18 @@ def encode(self, x): def decode(self, z): z = z / self.vae.config.scaling_factor - self.vae.enable_tiling() - if self.parallel and self.world_size_h is not None and self.world_size_w is not None: - video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w) - self.world_size_h, self.world_size_w = None, None - else: - video_frames = self.vae.decode(z, return_dict=False)[0] - self.vae.disable_tiling() + use_internal_tiling = not env_flag("LIGHTX2V_VAE_DISABLE_INTERNAL_TILING", False) + if use_internal_tiling: + self.vae.enable_tiling() + try: + if self.parallel and self.world_size_h is not None and self.world_size_w is not None: + video_frames = self.decode_dist_2d(z, self.world_size_h, self.world_size_w) + self.world_size_h, self.world_size_w = None, None + else: + video_frames = self.vae.decode(z, return_dict=False)[0] + finally: + if use_internal_tiling: + self.vae.disable_tiling() return video_frames @torch.no_grad() @@ -858,6 +937,7 @@ def decode_dist_2d(self, z, world_size_h, world_size_w): if detail_timing: self.device_synchronize() decode_start = time.perf_counter() + _maybe_register_vae_conv_shape_hooks(self.vae) images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0] if detail_timing: self.device_synchronize() From 4ad81a486e130637f0b33ad74884e9c884442eaa Mon Sep 17 00:00:00 2001 From: zhenggf Date: Tue, 30 Jun 2026 14:36:26 +0800 Subject: [PATCH 4/5] fix: harden Hunyuan VAE runtime edge cases --- .../feature_caching/transformer_infer.py | 5 ++-- lightx2v/models/runners/vae_postprocess.py | 24 ++++++++++++------- .../ops/attn/hygon_dcu/flash_attn.py | 4 ++-- 3 files changed, 21 insertions(+), 12 deletions(-) diff --git a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py index a72f8ed1b..5fc57b90d 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py @@ -147,8 +147,9 @@ def _relative_l1_distance(self, current, previous): diff_sum = (current - previous).abs().float().sum() prev_sum = previous.abs().float().sum().clamp_min(1e-12) stats = torch.stack([diff_sum, prev_sum]) - if self.seq_p_group is not None and dist.is_available() and dist.is_initialized(): - dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=self.seq_p_group) + seq_p_group = getattr(self, "seq_p_group", None) + if seq_p_group is not None and dist.is_available() and dist.is_initialized(): + dist.all_reduce(stats, op=dist.ReduceOp.SUM, group=seq_p_group) return (stats[0] / stats[1].clamp_min(1e-12)).cpu().item() def calculate_should_calc(self, img, vec, block): diff --git a/lightx2v/models/runners/vae_postprocess.py b/lightx2v/models/runners/vae_postprocess.py index 7f6889dc2..9fbf6ffd0 100644 --- a/lightx2v/models/runners/vae_postprocess.py +++ b/lightx2v/models/runners/vae_postprocess.py @@ -56,11 +56,19 @@ def sync_device_if_available(): def _spatial_dims(video): - if video.ndim != 5: - return -2, -1 - # VAE tensors are usually B,C,T,H,W before wan_vae_to_comfy and - # B,T,H,W,C afterwards. In both layouts, H/W are the two dims before - # channels only for the postprocessed form; rank0 crop is done before it. - if video.shape[1] in (1, 3, 4, 16, 32): - return 3, 4 - return 2, 3 + channel_sizes = (1, 3, 4, 16, 32) + if video.ndim == 5: + # B,T,H,W,C after postprocess: last dim is the channel count. + if video.shape[-1] in (1, 3, 4): + return 2, 3 + # B,C,T,H,W before postprocess: dim 1 is latent/image channels. + if video.shape[1] in channel_sizes: + return 3, 4 + elif video.ndim == 4: + # B,H,W,C image tensor. + if video.shape[-1] in (1, 3, 4): + return 1, 2 + # B,C,H,W tensor. + if video.shape[1] in channel_sizes: + return 2, 3 + return -2, -1 diff --git a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py index a9b9adcb8..bc9f6b298 100644 --- a/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py +++ b/lightx2v_platform/ops/attn/hygon_dcu/flash_attn.py @@ -113,9 +113,9 @@ def half(x): if softmax_scale is None: softmax_scale = 1.0 / math.sqrt(q.shape[-1]) - if cu_seqlens_q is not None and cu_seqlens_q.is_cpu: + if cu_seqlens_q is not None and cu_seqlens_q.device.type == "cpu": cu_seqlens_q = cu_seqlens_q.to(q_flat.device, non_blocking=True) - if cu_seqlens_kv is not None and cu_seqlens_kv.is_cpu: + if cu_seqlens_kv is not None and cu_seqlens_kv.device.type == "cpu": cu_seqlens_kv = cu_seqlens_kv.to(k_flat.device, non_blocking=True) # Use Flash Attention 2.6.1 (ROCm version) with varlen interface if SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and q.shape[1] == k.shape[1]: From 3a4c4abe8f3469cd818dbf185f6debcc5da3e22b Mon Sep 17 00:00:00 2001 From: zhenggf Date: Wed, 1 Jul 2026 14:59:33 +0800 Subject: [PATCH 5/5] style: format Hunyuan VAE runtime changes --- .../infer/feature_caching/transformer_infer.py | 2 +- .../hf/hunyuanvideo15/hunyuanvideo_15_vae.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py index 5fc57b90d..4906d99ed 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py @@ -3,8 +3,8 @@ import numpy as np import torch -import torch.nn.functional as F import torch.distributed as dist +import torch.nn.functional as F from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer from lightx2v_platform.base.global_var import AI_DEVICE diff --git a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py index dd34691b4..575436be8 100755 --- a/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py +++ b/lightx2v/models/video_encoders/hf/hunyuanvideo15/hunyuanvideo_15_vae.py @@ -5,7 +5,6 @@ from typing import Optional, Tuple, Union import numpy as np -from loguru import logger import torch import torch.distributed as dist import torch.nn.functional as F @@ -14,10 +13,11 @@ from diffusers.models.modeling_outputs import AutoencoderKLOutput from diffusers.models.modeling_utils import ModelMixin from einops import rearrange +from loguru import logger from torch import Tensor, nn -from lightx2v_platform.base.global_var import AI_DEVICE from lightx2v.models.runners.vae_postprocess import env_flag +from lightx2v_platform.base.global_var import AI_DEVICE torch_device_module = getattr(torch, AI_DEVICE) @@ -941,7 +941,9 @@ def decode_dist_2d(self, z, world_size_h, world_size_w): images_chunk = self.vae.decode(zs_chunk, return_dict=False)[0] if detail_timing: self.device_synchronize() - logger.info(f"[VAE_DETAIL] rank={cur_rank} decode_chunk_s={time.perf_counter() - decode_start:.6f} latent_chunk_shape={tuple(zs_chunk.shape)} decoded_chunk_shape={tuple(images_chunk.shape)}") + logger.info( + f"[VAE_DETAIL] rank={cur_rank} decode_chunk_s={time.perf_counter() - decode_start:.6f} latent_chunk_shape={tuple(zs_chunk.shape)} decoded_chunk_shape={tuple(images_chunk.shape)}" + ) # Remove padding from decoded chunk spatial_ratio = 16