diff --git a/lightx2v/common/ops/attn/ulysses_attn.py b/lightx2v/common/ops/attn/ulysses_attn.py index 93fff1eac..a0daf36e1 100755 --- a/lightx2v/common/ops/attn/ulysses_attn.py +++ b/lightx2v/common/ops/attn/ulysses_attn.py @@ -8,6 +8,21 @@ from .template import AttnWeightTemplate from .utils.all2all import all2all_head2seq + +def _is_split_qkv_input(tensor_or_pair): + return isinstance(tensor_or_pair, (tuple, list)) + + +def _to_3d_qkv(tensor): + if len(tensor.shape) == 4: + return tensor.reshape(-1, tensor.shape[-2], tensor.shape[-1]) + return tensor + + +def _contiguous_if_needed(tensor): + return tensor if tensor.is_contiguous() else tensor.contiguous() + + try: from sageattn3_sparse import dequant_fp4 as dequant_fp4_sage3 from sageattn3_sparse import quant_fp4 as quant_fp4_sage3 @@ -21,6 +36,20 @@ class UlyssesAttnWeight(AttnWeightTemplate): def __init__(self): self.config = {} + self._text_gather_buffers = {} + + def _get_text_gather_buffers(self, tensor, world_size, reuse_buffers=False, cache_max=16): + if not reuse_buffers: + return [torch.empty_like(tensor) for _ in range(world_size)] + + key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device) + buffers = self._text_gather_buffers.get(key) + if buffers is None: + if len(self._text_gather_buffers) >= max(1, cache_max): + self._text_gather_buffers.clear() + buffers = [torch.empty_like(tensor) for _ in range(world_size)] + self._text_gather_buffers[key] = buffers + return buffers def apply( self, @@ -37,6 +66,10 @@ def apply( enable_head_parallel=False, img_first=True, q_only_img=False, + return_split_output=False, + async_text_gather=False, + reuse_text_gather_buffers=False, + text_gather_buffer_cache_max=16, **kwargs, ): """ @@ -62,8 +95,16 @@ def apply( assert not (use_fp8_comm and use_fp4_comm), "use_fp8_comm and use_fp4_comm can't be enabled at the same time." use_qkv_fusion = use_tensor_fusion + split_qkv_input = _is_split_qkv_input(q) - if len(q.shape) == 4: + if split_qkv_input: + if q_only_img: + raise NotImplementedError("split QKV input does not support q_only_img yet.") + assert _is_split_qkv_input(k) and _is_split_qkv_input(v), "q/k/v must all use split img/txt input." + img_q, txt_q = (_to_3d_qkv(tensor) for tensor in q) + img_k, txt_k = (_to_3d_qkv(tensor) for tensor in k) + img_v, txt_v = (_to_3d_qkv(tensor) for tensor in v) + elif len(q.shape) == 4: q = q.reshape(-1, q.shape[-2], q.shape[-1]) k = k.reshape(-1, k.shape[-2], k.shape[-1]) v = v.reshape(-1, v.shape[-2], v.shape[-1]) @@ -73,7 +114,11 @@ def apply( cur_rank = dist.get_rank(seq_p_group) # 获取序列长度和文本相关的长度 - if img_first: + if split_qkv_input: + img_qkv_len = img_q.shape[0] + txt_qkv_len = txt_q.shape[0] + txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len if img_first and len(cu_seqlens_qkv) == 3 else None + elif img_first: img_qkv_len = slice_qkv_len if len(cu_seqlens_qkv) == 3: txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # 文本查询、键和值的长度 @@ -87,8 +132,12 @@ def apply( txt_mask_len = None # 分别获取 q 和 kv 的头数,支持 GQA(k/v 头数可能少于 q) - _, q_heads, hidden_dims = q.shape - _, kv_heads, _ = k.shape + if split_qkv_input: + _, q_heads, hidden_dims = img_q.shape + _, kv_heads, _ = img_k.shape + else: + _, q_heads, hidden_dims = q.shape + _, kv_heads, _ = k.shape is_gqa = q_heads != kv_heads q_shard_heads = q_heads // world_size # q 每个进程处理的头数 kv_shard_heads = kv_heads // world_size # k/v 每个进程处理的头数 @@ -118,7 +167,14 @@ def apply( max_seqlen_q = max_seqlen_kv # 分割图像和文本的查询、键和值 - if q_only_img: + if split_qkv_input: + img_q = _contiguous_if_needed(img_q) + img_k = _contiguous_if_needed(img_k) + img_v = _contiguous_if_needed(img_v) + txt_q = _contiguous_if_needed(txt_q) + txt_k = _contiguous_if_needed(txt_k) + txt_v = _contiguous_if_needed(txt_v) + elif q_only_img: # q 只含图像 token,无需分割;仅 k/v 需要拆出图像和文本部分 img_q = q.contiguous() txt_q = None @@ -462,13 +518,27 @@ def apply( txt_attn, img_attn = attn[:txt_qkv_len, :], attn[txt_qkv_len:] # 通信所有进程的图像注意力结果 + gathered_txt_attn = None + text_gather_work = None + if async_text_gather: + gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size, reuse_text_gather_buffers, text_gather_buffer_cache_max) + text_gather_work = dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group, async_op=True) + + # Finish the async gather before launching any later collective on the same process group. + if async_text_gather: + text_gather_work.wait() + img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, q_shard_heads, hidden_dims, seq_p_group, use_fp8_comm) - # 收集所有进程的文本注意力结果 - gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)] - dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) + # Gather text attention synchronously when async launch is disabled. + if not async_text_gather: + gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size, reuse_text_gather_buffers, text_gather_buffer_cache_max) + dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果 + if return_split_output: + return img_attn, txt_attn + # 合并图像和文本的注意力结果 if img_first: attn = torch.cat([img_attn, txt_attn], dim=0) diff --git a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py index 87341b4f6..cd1c2c92a 100755 --- a/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py +++ b/lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py @@ -3,6 +3,7 @@ import torch import torch.nn.functional as F from einops import rearrange +from loguru import logger try: from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace @@ -100,16 +101,29 @@ def __init__(self, config): self.config = config self.double_blocks_num = config["mm_double_blocks_depth"] self.heads_num = config["heads_num"] + parallel_config = self.config.get("parallel", {}) if self.config["seq_parallel"]: self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p") - self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False) - self.seq_p_fp4_comm = self.config["parallel"].get("seq_p_fp4_comm", False) - self.enable_head_parallel = self.config["parallel"].get("seq_p_head_parallel", False) + self.seq_p_fp8_comm = parallel_config.get("seq_p_fp8_comm", False) + self.seq_p_fp4_comm = parallel_config.get("seq_p_fp4_comm", False) + self.enable_head_parallel = parallel_config.get("seq_p_head_parallel", False) + self.seq_p_tensor_fusion = parallel_config.get("seq_p_tensor_fusion", False) + self.seq_p_split_qkv_input = parallel_config.get("seq_p_split_qkv_input", False) + self.seq_p_split_attn_output = parallel_config.get("seq_p_split_attn_output", False) + self.seq_p_async_text_gather = parallel_config.get("seq_p_async_text_gather", False) + self.seq_p_reuse_text_gather_buffers = parallel_config.get("seq_p_reuse_text_gather_buffers", False) + self.seq_p_text_gather_buffer_cache_max = parallel_config.get("seq_p_text_gather_buffer_cache_max", 16) else: self.seq_p_group = None self.seq_p_fp8_comm = False self.seq_p_fp4_comm = False self.enable_head_parallel = False + self.seq_p_tensor_fusion = False + self.seq_p_split_qkv_input = False + self.seq_p_split_attn_output = False + self.seq_p_async_text_gather = False + self.seq_p_reuse_text_gather_buffers = False + self.seq_p_text_gather_buffer_cache_max = 16 self.infer_func = self.infer_without_offload if self.config.get("modulate_type", "triton") == "triton": self.modulate_func = fuse_scale_shift_kernel @@ -119,6 +133,31 @@ def __init__(self, config): self.apply_rope_func = apply_hunyuan_rope_with_flashinfer else: self.apply_rope_func = apply_hunyuan_rope_with_torch + self.compile_non_attn = self.config.get("compile_dit_non_attn", False) + self.compile_before_attn_requested = self.config.get("compile_dit_before_attn", False) + self.compile_before_attn = self.compile_before_attn_requested and self.config.get("compile_dit_before_attn_unsafe", False) + self.compile_non_attn_mode = self.config.get("compile_dit_mode", "reduce-overhead") + self.share_qkv_act_quant = self.config.get("share_qkv_act_quant", False) + self._compiled_non_attn = {} + self._compile_non_attn_failed = set() + if self.share_qkv_act_quant: + logger.info("[Quant] Reusing dynamic activation quantization for consecutive Q/K/V projections") + if self.seq_p_split_qkv_input: + logger.info("[Ulysses] Passing split img/txt QKV tensors to avoid pre-attention concat/slice copies") + if self.seq_p_split_attn_output: + logger.info("[Ulysses] Returning split img/txt attention outputs to avoid post-attention concat/slice copies") + if self.compile_before_attn_requested and not self.compile_before_attn: + logger.warning( + "[Compile] compile_dit_before_attn is ignored unless " + "compile_dit_before_attn_unsafe is also set; before-attn graphs carry " + "block-specific weight objects and can grow Dynamo caches quickly." + ) + if self.compile_non_attn or self.compile_before_attn: + try: + torch._dynamo.config.suppress_errors = True + except Exception as exc: + logger.warning(f"[Compile] Unable to enable Dynamo suppress_errors: {exc}") + logger.info(f"[Compile] Hunyuan DiT branch compile: after_attn={self.compile_non_attn}, before_attn={self.compile_before_attn}, mode={self.compile_non_attn_mode}") def set_scheduler(self, scheduler): self.scheduler = scheduler @@ -153,8 +192,33 @@ def infer_double_block(self, weights, infer_module_out): txt = self._infer_txt_branch_after_attn(weights, txt_attn, infer_module_out.txt, txt_branch_out) return img, txt + def _apply_qkv(self, q_weight, k_weight, v_weight, hidden_states): + use_shared_quant = ( + self.share_qkv_act_quant + and hasattr(q_weight, "prepare_quantized_input") + and all(hasattr(weight, "apply_quantized_input") for weight in (q_weight, k_weight, v_weight)) + and not any(getattr(weight, "use_bf16_fallback", False) for weight in (q_weight, k_weight, v_weight)) + ) + if use_shared_quant: + quantized_input = q_weight.prepare_quantized_input(hidden_states) + return ( + q_weight.apply_quantized_input(hidden_states, quantized_input), + k_weight.apply_quantized_input(hidden_states, quantized_input), + v_weight.apply_quantized_input(hidden_states, quantized_input), + ) + return q_weight.apply(hidden_states), k_weight.apply(hidden_states), v_weight.apply(hidden_states) + @torch.no_grad() def _infer_img_branch_before_attn(self, weights, infer_module_out): + return self._run_non_attn_branch( + "img_before_attn", + self._infer_img_branch_before_attn_eager, + weights, + infer_module_out, + compile_enabled=self.compile_before_attn, + ) + + def _infer_img_branch_before_attn_eager(self, weights, infer_module_out): ( img_mod1_shift, img_mod1_scale, @@ -165,9 +229,12 @@ def _infer_img_branch_before_attn(self, weights, infer_module_out): ) = weights.img_branch.img_mod.apply(infer_module_out.vec).chunk(6, dim=-1) img_modulated = weights.img_branch.img_norm1.apply(infer_module_out.img.squeeze(0)) img_modulated = self.modulate_func(img_modulated, scale=img_mod1_scale, shift=img_mod1_shift).squeeze(0) - img_q = weights.img_branch.img_attn_q.apply(img_modulated) - img_k = weights.img_branch.img_attn_k.apply(img_modulated) - img_v = weights.img_branch.img_attn_v.apply(img_modulated) + img_q, img_k, img_v = self._apply_qkv( + weights.img_branch.img_attn_q, + weights.img_branch.img_attn_k, + weights.img_branch.img_attn_v, + img_modulated, + ) img_q = rearrange(img_q, "L (H D) -> L H D", H=self.heads_num) img_k = rearrange(img_k, "L (H D) -> L H D", H=self.heads_num) img_v = rearrange(img_v, "L (H D) -> L H D", H=self.heads_num) @@ -188,6 +255,15 @@ def _infer_img_branch_before_attn(self, weights, infer_module_out): @torch.no_grad() def _infer_txt_branch_before_attn(self, weights, infer_module_out): + return self._run_non_attn_branch( + "txt_before_attn", + self._infer_txt_branch_before_attn_eager, + weights, + infer_module_out, + compile_enabled=self.compile_before_attn, + ) + + def _infer_txt_branch_before_attn_eager(self, weights, infer_module_out): ( txt_mod1_shift, txt_mod1_scale, @@ -198,9 +274,12 @@ def _infer_txt_branch_before_attn(self, weights, infer_module_out): ) = weights.txt_branch.txt_mod.apply(infer_module_out.vec).chunk(6, dim=-1) txt_modulated = weights.txt_branch.txt_norm1.apply(infer_module_out.txt.squeeze(0)) txt_modulated = self.modulate_func(txt_modulated, scale=txt_mod1_scale, shift=txt_mod1_shift).squeeze(0) - txt_q = weights.txt_branch.txt_attn_q.apply(txt_modulated) - txt_k = weights.txt_branch.txt_attn_k.apply(txt_modulated) - txt_v = weights.txt_branch.txt_attn_v.apply(txt_modulated) + txt_q, txt_k, txt_v = self._apply_qkv( + weights.txt_branch.txt_attn_q, + weights.txt_branch.txt_attn_k, + weights.txt_branch.txt_attn_v, + txt_modulated, + ) txt_q = rearrange(txt_q, "L (H D) -> L H D", H=self.heads_num) txt_k = rearrange(txt_k, "L (H D) -> L H D", H=self.heads_num) txt_v = rearrange(txt_v, "L (H D) -> L H D", H=self.heads_num) @@ -221,33 +300,89 @@ def _infer_txt_branch_before_attn(self, weights, infer_module_out): @torch.no_grad() def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v): img_seqlen = img_q.shape[1] - query = torch.cat([img_q, txt_q], dim=1) - key = torch.cat([img_k, txt_k], dim=1) - value = torch.cat([img_v, txt_v], dim=1) - seqlen = query.shape[1] + txt_seqlen = txt_q.shape[1] + seqlen = img_seqlen + txt_seqlen cu_seqlens_qkv = torch.tensor([0, seqlen], dtype=torch.int32, device="cpu") - if self.config["seq_parallel"]: + if self.config["seq_parallel"] and self.seq_p_split_qkv_input: attn_out = weights.self_attention_parallel.apply( - q=query, - k=key, - v=value, + q=(img_q, txt_q), + k=(img_k, txt_k), + v=(img_v, txt_v), slice_qkv_len=img_seqlen, cu_seqlens_qkv=cu_seqlens_qkv, attention_module=weights.self_attention, seq_p_group=self.seq_p_group, use_fp8_comm=self.seq_p_fp8_comm, use_fp4_comm=self.seq_p_fp4_comm, + use_tensor_fusion=self.seq_p_tensor_fusion, enable_head_parallel=self.enable_head_parallel, + return_split_output=self.seq_p_split_attn_output, + async_text_gather=self.seq_p_async_text_gather, + reuse_text_gather_buffers=self.seq_p_reuse_text_gather_buffers, + text_gather_buffer_cache_max=self.seq_p_text_gather_buffer_cache_max, ) else: - attn_out = weights.self_attention.apply(q=query, k=key, v=value, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=seqlen, max_seqlen_kv=seqlen) - - img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:] + query = torch.cat([img_q, txt_q], dim=1) + key = torch.cat([img_k, txt_k], dim=1) + value = torch.cat([img_v, txt_v], dim=1) + if self.config["seq_parallel"]: + attn_out = weights.self_attention_parallel.apply( + q=query, + k=key, + v=value, + slice_qkv_len=img_seqlen, + cu_seqlens_qkv=cu_seqlens_qkv, + attention_module=weights.self_attention, + seq_p_group=self.seq_p_group, + use_fp8_comm=self.seq_p_fp8_comm, + use_fp4_comm=self.seq_p_fp4_comm, + use_tensor_fusion=self.seq_p_tensor_fusion, + enable_head_parallel=self.enable_head_parallel, + return_split_output=self.seq_p_split_attn_output, + async_text_gather=self.seq_p_async_text_gather, + reuse_text_gather_buffers=self.seq_p_reuse_text_gather_buffers, + text_gather_buffer_cache_max=self.seq_p_text_gather_buffer_cache_max, + ) + else: + attn_out = weights.self_attention.apply(q=query, k=key, v=value, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=seqlen, max_seqlen_kv=seqlen) + + if isinstance(attn_out, (tuple, list)): + img_attn, txt_attn = attn_out + else: + img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:] return img_attn, txt_attn + def _run_non_attn_branch(self, graph_name, eager_fn, *args, compile_enabled=None): + if compile_enabled is None: + compile_enabled = self.compile_non_attn + if not compile_enabled or graph_name in self._compile_non_attn_failed: + return eager_fn(*args) + + compiled_fn = self._compiled_non_attn.get(graph_name) + if compiled_fn is None: + try: + compiled_fn = torch.compile(eager_fn, fullgraph=False, dynamic=False, mode=self.compile_non_attn_mode) + self._compiled_non_attn[graph_name] = compiled_fn + logger.info(f"[Compile] Created compiled wrapper for {graph_name}") + except Exception as exc: + self._compile_non_attn_failed.add(graph_name) + logger.warning(f"[Compile] Failed to create compiled wrapper for {graph_name}, fallback to eager: {exc}") + return eager_fn(*args) + + try: + return compiled_fn(*args) + except Exception as exc: + self._compile_non_attn_failed.add(graph_name) + self._compiled_non_attn.pop(graph_name, None) + logger.warning(f"[Compile] Runtime failure in {graph_name}, disabling this graph and falling back to eager: {exc}") + return eager_fn(*args) + @torch.no_grad() def _infer_img_branch_after_attn(self, weights, img_attn, img, img_branch_out): + return self._run_non_attn_branch("img_after_attn", self._infer_img_branch_after_attn_eager, weights, img_attn, img, img_branch_out) + + def _infer_img_branch_after_attn_eager(self, weights, img_attn, img, img_branch_out): img = img + apply_gate(weights.img_branch.img_attn_proj.apply(img_attn).unsqueeze(0), gate=img_branch_out.img_mod1_gate) out = weights.img_branch.img_mlp_fc1.apply( self.modulate_func(weights.img_branch.img_norm2.apply(img.squeeze(0)), scale=img_branch_out.img_mod2_scale, shift=img_branch_out.img_mod2_shift).squeeze(0) @@ -258,6 +393,9 @@ def _infer_img_branch_after_attn(self, weights, img_attn, img, img_branch_out): @torch.no_grad() def _infer_txt_branch_after_attn(self, weights, txt_attn, txt, txt_branch_out): + return self._run_non_attn_branch("txt_after_attn", self._infer_txt_branch_after_attn_eager, weights, txt_attn, txt, txt_branch_out) + + def _infer_txt_branch_after_attn_eager(self, weights, txt_attn, txt, txt_branch_out): txt = txt + apply_gate(weights.txt_branch.txt_attn_proj.apply(txt_attn).unsqueeze(0), gate=txt_branch_out.txt_mod1_gate) out = weights.txt_branch.txt_mlp_fc1.apply( self.modulate_func(weights.txt_branch.txt_norm2.apply(txt.squeeze(0)), scale=txt_branch_out.txt_mod2_scale, shift=txt_branch_out.txt_mod2_shift).squeeze(0)