Skip to content

perf: optimize Hunyuan DiT Ulysses and non-attention paths#1200

Draft
starrkk wants to merge 4 commits into
ModelTC:mainfrom
starrkk:codex/hunyuan-dit-ulysses-optimizations
Draft

perf: optimize Hunyuan DiT Ulysses and non-attention paths#1200
starrkk wants to merge 4 commits into
ModelTC:mainfrom
starrkk:codex/hunyuan-dit-ulysses-optimizations

Conversation

@starrkk

@starrkk starrkk commented Jun 30, 2026

Copy link
Copy Markdown

Summary

  • add optional Hunyuan DiT non-attention torch.compile wrappers
  • add split QKV input/output support for Ulysses attention
  • add async text gather and profile ranges for Ulysses attention
  • wire shared QKV activation quantization into the Hunyuan transformer path

Why

These switches reduce Python/tensor layout overhead around HunyuanVideo DiT inference and let Hygon DCU deployments reuse activation quantization for consecutive Q/K/V projections. Defaults remain opt-in through environment variables.

Validation

  • branch rebuilt on latest ModelTC/LightX2V:main (89dfa833)
  • git diff --check passed for the PR branch
  • validated as part of the HunyuanVideo1.5 I2V 8-card benchmark path on Hygon DCU

zhenggf added 3 commits June 30, 2026 11:50
(cherry picked from commit 8f06fb6c7e0859f432a329a84f8d5d8e3a386ad1)
Support split image/text QKV inputs, optional split attention outputs, async text all_gather, and profiler ranges for Ulysses sequence-parallel attention.

(cherry picked from commit 8bb7c3e1784140a8f6d372fe429b468e3a502b8b)
Reuse dynamic activation quantization across consecutive Q/K/V projections and route split image/text tensors through the Ulysses attention path when enabled.

(cherry picked from commit 61c5df5c20106254d5294b910cdf3d1780970a97)

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces several performance optimizations for Ulysses attention and Hunyuan Video transformer inference, including support for split QKV inputs/outputs to avoid copy overhead, asynchronous text gathering, buffer reuse, shared dynamic activation quantization, and optional torch.compile support for non-attention branches. The reviewer identified several critical issues: potential NCCL hangs due to overlapping collective operations on the same process group, high compilation overhead and cache thrashing from compiling functions with custom weight objects, an unbounded memory leak in the text gather buffer cache under dynamic prompt lengths, incorrect text mask length calculation when using split QKV inputs, and a potential AttributeError in the shared quantization check if key/value weights lack the expected quantization methods.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Comment on lines +540 to +554
if _ASYNC_TEXT_GATHER_ENABLED:
with _profile_range("text_all_gather_launch"):
gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size)
text_gather_work = dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group, async_op=True)

img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, q_shard_heads, hidden_dims, seq_p_group, use_fp8_comm)

# 收集所有进程的文本注意力结果
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果
if _ASYNC_TEXT_GATHER_ENABLED:
with _profile_range("text_all_gather_wait"):
text_gather_work.wait()
else:
with _profile_range("text_all_gather"):
gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size)
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When _ASYNC_TEXT_GATHER_ENABLED is enabled, dist.all_gather is launched asynchronously. However, before calling wait() on the work handle, the code calls _reshape_img_attn, which internally invokes all2all_head2seq and launches a synchronous dist.all_to_all_single on the same process group (seq_p_group). In PyTorch distributed, launching another collective operation on the same process group while an async collective is outstanding violates the API contract and can lead to NCCL hangs, crashes, or undefined behavior. Since both operations use the same process group, they are serialized on the same NCCL stream anyway, so async gather does not provide GPU-side overlap. It is highly recommended to avoid overlapping collectives on the same process group.

Comment on lines +345 to +368
def _run_non_attn_branch(self, graph_name, eager_fn, *args, compile_enabled=None):
if compile_enabled is None:
compile_enabled = self.compile_non_attn
if not compile_enabled or graph_name in self._compile_non_attn_failed:
return eager_fn(*args)

compiled_fn = self._compiled_non_attn.get(graph_name)
if compiled_fn is None:
try:
compiled_fn = torch.compile(eager_fn, fullgraph=False, dynamic=False, mode=self.compile_non_attn_mode)
self._compiled_non_attn[graph_name] = compiled_fn
logger.info(f"[Compile] Created compiled wrapper for {graph_name}")
except Exception as exc:
self._compile_non_attn_failed.add(graph_name)
logger.warning(f"[Compile] Failed to create compiled wrapper for {graph_name}, fallback to eager: {exc}")
return eager_fn(*args)

try:
return compiled_fn(*args)
except Exception as exc:
self._compile_non_attn_failed.add(graph_name)
self._compiled_non_attn.pop(graph_name, None)
logger.warning(f"[Compile] Runtime failure in {graph_name}, disabling this graph and falling back to eager: {exc}")
return eager_fn(*args)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Using torch.compile on functions like _infer_img_branch_before_attn_eager that accept weights (which is a different custom Python object for each of the double_blocks) will trigger a graph recompilation for every single block. This is because Dynamo guards on the identity and attributes of the custom weights object. This leads to extremely high warm-up latency and will quickly exceed the default Dynamo cache size limit, causing it to fall back to eager execution. To make torch.compile effective, the compiled functions should only take Tensors as inputs (not custom weight objects), or the compilation should be applied at a higher level.

Comment on lines +57 to +66
def _get_text_gather_buffers(self, tensor, world_size):
if not _REUSE_TEXT_GATHER_BUFFERS_ENABLED:
return [torch.empty_like(tensor) for _ in range(world_size)]

key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device)
buffers = self._text_gather_buffers.get(key)
if buffers is None:
buffers = [torch.empty_like(tensor) for _ in range(world_size)]
self._text_gather_buffers[key] = buffers
return buffers

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The self._text_gather_buffers dictionary caches empty tensors for text gather using tuple(tensor.shape) as part of the key. In a serving environment with dynamic prompt lengths, the text sequence length varies per request. This causes the dictionary to grow indefinitely, leading to a memory leak and eventual OOM. It is recommended to limit the cache size or clear it when it exceeds a threshold.

    def _get_text_gather_buffers(self, tensor, world_size):
        if not _REUSE_TEXT_GATHER_BUFFERS_ENABLED:
            return [torch.empty_like(tensor) for _ in range(world_size)]

        key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device)
        buffers = self._text_gather_buffers.get(key)
        if buffers is None:
            if len(self._text_gather_buffers) >= 16:
                self._text_gather_buffers.clear()
            buffers = [torch.empty_like(tensor) for _ in range(world_size)]
            self._text_gather_buffers[key] = buffers
        return buffers

Comment on lines +128 to +131
if split_qkv_input:
img_qkv_len = img_q.shape[0]
txt_qkv_len = txt_q.shape[0]
txt_mask_len = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

When split_qkv_input is True, txt_mask_len is unconditionally set to None. However, if len(cu_seqlens_qkv) == 3, there is a text mask, and txt_mask_len should be computed as cu_seqlens_qkv[2] - slice_qkv_len to ensure correctness for other models or configurations using this generic Ulysses attention implementation.

Suggested change
if split_qkv_input:
img_qkv_len = img_q.shape[0]
txt_qkv_len = txt_q.shape[0]
txt_mask_len = None
if split_qkv_input:
img_qkv_len = img_q.shape[0]
txt_qkv_len = txt_q.shape[0]
txt_mask_len = cu_seqlens_qkv[2] - slice_qkv_len if len(cu_seqlens_qkv) == 3 else None

Comment on lines +190 to +196
def _apply_qkv(self, q_weight, k_weight, v_weight, hidden_states):
use_shared_quant = (
self.share_qkv_act_quant
and hasattr(q_weight, "prepare_quantized_input")
and hasattr(q_weight, "apply_quantized_input")
and not any(getattr(weight, "use_bf16_fallback", False) for weight in (q_weight, k_weight, v_weight))
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use_shared_quant check only verifies hasattr(q_weight, "prepare_quantized_input") and hasattr(q_weight, "apply_quantized_input"). It does not verify if k_weight and v_weight also have the apply_quantized_input method. If they don't (e.g., in some mixed-precision or custom GQA setups), it will raise an AttributeError at runtime. It is safer to check all three weights.

    def _apply_qkv(self, q_weight, k_weight, v_weight, hidden_states):
        use_shared_quant = (
            self.share_qkv_act_quant
            and hasattr(q_weight, "prepare_quantized_input")
            and hasattr(q_weight, "apply_quantized_input")
            and hasattr(k_weight, "apply_quantized_input")
            and hasattr(v_weight, "apply_quantized_input")
            and not any(getattr(weight, "use_bf16_fallback", False) for weight in (q_weight, k_weight, v_weight))
        )

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant