perf: optimize Hunyuan DiT Ulysses and non-attention paths #1200
starrkk wants to merge 4 commits
(cherry picked from commit 8f06fb6c7e0859f432a329a84f8d5d8e3a386ad1)
Support split image/text QKV inputs, optional split attention outputs, async text all_gather, and profiler ranges for Ulysses sequence-parallel attention. (cherry picked from commit 8bb7c3e1784140a8f6d372fe429b468e3a502b8b)
Reuse dynamic activation quantization across consecutive Q/K/V projections and route split image/text tensors through the Ulysses attention path when enabled. (cherry picked from commit 61c5df5c20106254d5294b910cdf3d1780970a97)
Code Review
This pull request introduces several performance optimizations for Ulysses attention and Hunyuan Video transformer inference, including support for split QKV inputs/outputs to avoid copy overhead, asynchronous text gathering, buffer reuse, shared dynamic activation quantization, and optional torch.compile support for non-attention branches. The reviewer identified several critical issues: potential NCCL hangs due to overlapping collective operations on the same process group, high compilation overhead and cache thrashing from compiling functions with custom weight objects, an unbounded memory leak in the text gather buffer cache under dynamic prompt lengths, incorrect text mask length calculation when using split QKV inputs, and a potential AttributeError in the shared quantization check if key/value weights lack the expected quantization methods.
| if _ASYNC_TEXT_GATHER_ENABLED: | ||
| with _profile_range("text_all_gather_launch"): | ||
| gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size) | ||
| text_gather_work = dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group, async_op=True) | ||
|
|
||
| img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, q_shard_heads, hidden_dims, seq_p_group, use_fp8_comm) | ||
|
|
||
| # Gather the text attention results from all processes | ||
| gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)] | ||
| dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) | ||
| txt_attn = torch.cat(gathered_txt_attn, dim=1) # Concatenate the text attention results from all processes | ||
| if _ASYNC_TEXT_GATHER_ENABLED: | ||
| with _profile_range("text_all_gather_wait"): | ||
| text_gather_work.wait() | ||
| else: | ||
| with _profile_range("text_all_gather"): | ||
| gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size) | ||
| dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group) |
When _ASYNC_TEXT_GATHER_ENABLED is enabled, dist.all_gather is launched asynchronously. However, before calling wait() on the work handle, the code calls _reshape_img_attn, which internally invokes all2all_head2seq and launches a synchronous dist.all_to_all_single on the same process group (seq_p_group). In PyTorch distributed, launching another collective operation on the same process group while an async collective is outstanding violates the API contract and can lead to NCCL hangs, crashes, or undefined behavior. Since both operations use the same process group, they are serialized on the same NCCL stream anyway, so async gather does not provide GPU-side overlap. It is highly recommended to avoid overlapping collectives on the same process group.
| def _run_non_attn_branch(self, graph_name, eager_fn, *args, compile_enabled=None): | ||
| if compile_enabled is None: | ||
| compile_enabled = self.compile_non_attn | ||
| if not compile_enabled or graph_name in self._compile_non_attn_failed: | ||
| return eager_fn(*args) | ||
|
|
||
| compiled_fn = self._compiled_non_attn.get(graph_name) | ||
| if compiled_fn is None: | ||
| try: | ||
| compiled_fn = torch.compile(eager_fn, fullgraph=False, dynamic=False, mode=self.compile_non_attn_mode) | ||
| self._compiled_non_attn[graph_name] = compiled_fn | ||
| logger.info(f"[Compile] Created compiled wrapper for {graph_name}") | ||
| except Exception as exc: | ||
| self._compile_non_attn_failed.add(graph_name) | ||
| logger.warning(f"[Compile] Failed to create compiled wrapper for {graph_name}, fallback to eager: {exc}") | ||
| return eager_fn(*args) | ||
|
|
||
| try: | ||
| return compiled_fn(*args) | ||
| except Exception as exc: | ||
| self._compile_non_attn_failed.add(graph_name) | ||
| self._compiled_non_attn.pop(graph_name, None) | ||
| logger.warning(f"[Compile] Runtime failure in {graph_name}, disabling this graph and falling back to eager: {exc}") | ||
| return eager_fn(*args) |
Using torch.compile on functions like _infer_img_branch_before_attn_eager that accept weights (which is a different custom Python object for each of the double_blocks) will trigger a graph recompilation for every single block. This is because Dynamo guards on the identity and attributes of the custom weights object. This leads to extremely high warm-up latency and will quickly exceed the default Dynamo cache size limit, causing it to fall back to eager execution. To make torch.compile effective, the compiled functions should only take Tensors as inputs (not custom weight objects), or the compilation should be applied at a higher level.
| def _get_text_gather_buffers(self, tensor, world_size): | ||
| if not _REUSE_TEXT_GATHER_BUFFERS_ENABLED: | ||
| return [torch.empty_like(tensor) for _ in range(world_size)] | ||
|
|
||
| key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device) | ||
| buffers = self._text_gather_buffers.get(key) | ||
| if buffers is None: | ||
| buffers = [torch.empty_like(tensor) for _ in range(world_size)] | ||
| self._text_gather_buffers[key] = buffers | ||
| return buffers |
The self._text_gather_buffers dictionary caches empty tensors for text gather using tuple(tensor.shape) as part of the key. In a serving environment with dynamic prompt lengths, the text sequence length varies per request. This causes the dictionary to grow indefinitely, leading to a memory leak and eventual OOM. It is recommended to limit the cache size or clear it when it exceeds a threshold.
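One way to bound the cache without discarding every entry at once is least-recently-used eviction. The sketch below is a minimal illustration, not code from the PR: `BoundedBufferCache`, `max_entries`, and `make_buffers` are hypothetical names, and the buffers are created by a caller-supplied factory so the example runs without PyTorch.

```python
from collections import OrderedDict
from typing import Callable, Hashable, List


class BoundedBufferCache:
    """LRU cache for reusable buffer lists (hypothetical helper, not part of the PR)."""

    def __init__(self, max_entries: int = 16):
        if max_entries < 1:
            raise ValueError("max_entries must be at least 1")
        self.max_entries = max_entries
        self._entries: "OrderedDict[Hashable, List]" = OrderedDict()

    def get(self, key: Hashable, make_buffers: Callable[[], List]) -> List:
        buffers = self._entries.get(key)
        if buffers is not None:
            # Cache hit: mark this key as most recently used.
            self._entries.move_to_end(key)
            return buffers
        # Cache miss: allocate, store, then evict the oldest entry if over capacity.
        buffers = make_buffers()
        self._entries[key] = buffers
        if len(self._entries) > self.max_entries:
            self._entries.popitem(last=False)
        return buffers

    def __contains__(self, key: Hashable) -> bool:
        return key in self._entries

    def __len__(self) -> int:
        return len(self._entries)
```

In `_get_text_gather_buffers`, the factory would presumably be `lambda: [torch.empty_like(tensor) for _ in range(world_size)]` and the key the same `(world_size, shape, dtype, device)` tuple. Frequently used prompt lengths then stay cached while rarely used ones are dropped one at a time.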
def _get_text_gather_buffers(self, tensor, world_size):
    if not _REUSE_TEXT_GATHER_BUFFERS_ENABLED:
        return [torch.empty_like(tensor) for _ in range(world_size)]
    key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device)
    buffers = self._text_gather_buffers.get(key)
    if buffers is None:
        if len(self._text_gather_buffers) >= 16:
            self._text_gather_buffers.clear()
        buffers = [torch.empty_like(tensor) for _ in range(world_size)]
        self._text_gather_buffers[key] = buffers
    return buffers
| if split_qkv_input: | ||
| img_qkv_len = img_q.shape[0] | ||
| txt_qkv_len = txt_q.shape[0] | ||
| txt_mask_len = None |
When split_qkv_input is True, txt_mask_len is unconditionally set to None. However, if len(cu_seqlens_qkv) == 3, there is a text mask, and txt_mask_len should be computed as cu_seqlens_qkv[2] - slice_qkv_len to ensure correctness for other models or configurations using this generic Ulysses attention implementation.
| if split_qkv_input: | |
| img_qkv_len = img_q.shape[0] | |
| txt_qkv_len = txt_q.shape[0] | |
| txt_mask_len = None | |
| if split_qkv_input: | |
| img_qkv_len = img_q.shape[0] | |
| txt_qkv_len = txt_q.shape[0] | |
| txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len if len(cu_seqlens_qkv) == 3 else None |
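The suggested expression can be checked in isolation with plain integers. This is a sketch under assumptions: the helper name `text_mask_len` is hypothetical, and it assumes `slice_qkv_len` is the number of image tokens preceding the text and that a three-entry `cu_seqlens_qkv` ends at the last masked text position.

```python
from typing import Optional, Sequence


def text_mask_len(cu_seqlens_qkv: Sequence[int], slice_qkv_len: int) -> Optional[int]:
    """Mirror the suggested expression (hypothetical helper name).

    Returns the text mask length when a third cumulative offset is present,
    otherwise None, meaning no text mask.
    """
    if len(cu_seqlens_qkv) == 3:
        return cu_seqlens_qkv[2] - slice_qkv_len
    return None


# Illustrative numbers only: 1000 image tokens, third offset at 1256.
with_mask = text_mask_len([0, 1100, 1256], slice_qkv_len=1000)
without_mask = text_mask_len([0, 1100], slice_qkv_len=1000)
```

With these inputs, `with_mask` is 256 and `without_mask` is `None`, so the split-input path would keep the mask length instead of dropping it.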
| def _apply_qkv(self, q_weight, k_weight, v_weight, hidden_states): | ||
| use_shared_quant = ( | ||
| self.share_qkv_act_quant | ||
| and hasattr(q_weight, "prepare_quantized_input") | ||
| and hasattr(q_weight, "apply_quantized_input") | ||
| and not any(getattr(weight, "use_bf16_fallback", False) for weight in (q_weight, k_weight, v_weight)) | ||
| ) |
The use_shared_quant check only verifies hasattr(q_weight, "prepare_quantized_input") and hasattr(q_weight, "apply_quantized_input"). It does not verify if k_weight and v_weight also have the apply_quantized_input method. If they don't (e.g., in some mixed-precision or custom GQA setups), it will raise an AttributeError at runtime. It is safer to check all three weights.
def _apply_qkv(self, q_weight, k_weight, v_weight, hidden_states):
    use_shared_quant = (
        self.share_qkv_act_quant
        and hasattr(q_weight, "prepare_quantized_input")
        and hasattr(q_weight, "apply_quantized_input")
        and hasattr(k_weight, "apply_quantized_input")
        and hasattr(v_weight, "apply_quantized_input")
        and not any(getattr(weight, "use_bf16_fallback", False) for weight in (q_weight, k_weight, v_weight))
    )
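The same guard can be written with `all(...)` so that every weight is checked uniformly. The sketch below is illustrative only: `can_share_act_quant`, `QuantWeight`, and `PlainWeight` are hypothetical stand-ins, and only the attribute names come from the diff.

```python
def can_share_act_quant(q_weight, k_weight, v_weight, enabled=True):
    """Return True only if all three weights support the shared-quantization path."""
    weights = (q_weight, k_weight, v_weight)
    return bool(
        enabled
        # Only the Q weight needs to produce the quantized activation ...
        and hasattr(q_weight, "prepare_quantized_input")
        # ... but every weight must be able to consume it.
        and all(hasattr(w, "apply_quantized_input") for w in weights)
        and not any(getattr(w, "use_bf16_fallback", False) for w in weights)
    )


class QuantWeight:
    """Stand-in for a quantized weight object (hypothetical)."""

    use_bf16_fallback = False

    def prepare_quantized_input(self, x):
        return x

    def apply_quantized_input(self, x):
        return x


class PlainWeight:
    """Stand-in for a weight without quantization methods (hypothetical)."""


all_quantized = can_share_act_quant(QuantWeight(), QuantWeight(), QuantWeight())
mixed = can_share_act_quant(QuantWeight(), PlainWeight(), QuantWeight())
```

Here `all_quantized` is `True` and `mixed` is `False`, so a mixed-precision configuration would fall back to per-projection quantization rather than raising `AttributeError`.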
Summary
torch.compile wrappers
Why
These switches reduce Python/tensor layout overhead around HunyuanVideo DiT inference and let Hygon DCU deployments reuse activation quantization for consecutive Q/K/V projections. Defaults remain opt-in through environment variables.
Validation
ModelTC/LightX2V:main (89dfa833)
git diff --check passed for the PR branch