Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions lightx2v/common/ops/attn/ulysses_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,21 @@
from .template import AttnWeightTemplate
from .utils.all2all import all2all_head2seq


def _is_split_qkv_input(tensor_or_pair):
return isinstance(tensor_or_pair, (tuple, list))


def _to_3d_qkv(tensor):
if len(tensor.shape) == 4:
return tensor.reshape(-1, tensor.shape[-2], tensor.shape[-1])
return tensor


def _contiguous_if_needed(tensor):
return tensor if tensor.is_contiguous() else tensor.contiguous()


try:
from sageattn3_sparse import dequant_fp4 as dequant_fp4_sage3
from sageattn3_sparse import quant_fp4 as quant_fp4_sage3
Expand All @@ -21,6 +36,20 @@
class UlyssesAttnWeight(AttnWeightTemplate):
def __init__(self):
self.config = {}
self._text_gather_buffers = {}

def _get_text_gather_buffers(self, tensor, world_size, reuse_buffers=False, cache_max=16):
if not reuse_buffers:
return [torch.empty_like(tensor) for _ in range(world_size)]

key = (world_size, tuple(tensor.shape), tensor.dtype, tensor.device)
buffers = self._text_gather_buffers.get(key)
if buffers is None:
if len(self._text_gather_buffers) >= max(1, cache_max):
self._text_gather_buffers.clear()
buffers = [torch.empty_like(tensor) for _ in range(world_size)]
self._text_gather_buffers[key] = buffers
return buffers

def apply(
self,
Expand All @@ -37,6 +66,10 @@ def apply(
enable_head_parallel=False,
img_first=True,
q_only_img=False,
return_split_output=False,
async_text_gather=False,
reuse_text_gather_buffers=False,
text_gather_buffer_cache_max=16,
**kwargs,
):
"""
Expand All @@ -62,8 +95,16 @@ def apply(
assert not (use_fp8_comm and use_fp4_comm), "use_fp8_comm and use_fp4_comm can't be enabled at the same time."

use_qkv_fusion = use_tensor_fusion
split_qkv_input = _is_split_qkv_input(q)

if len(q.shape) == 4:
if split_qkv_input:
if q_only_img:
raise NotImplementedError("split QKV input does not support q_only_img yet.")
assert _is_split_qkv_input(k) and _is_split_qkv_input(v), "q/k/v must all use split img/txt input."
img_q, txt_q = (_to_3d_qkv(tensor) for tensor in q)
img_k, txt_k = (_to_3d_qkv(tensor) for tensor in k)
img_v, txt_v = (_to_3d_qkv(tensor) for tensor in v)
elif len(q.shape) == 4:
q = q.reshape(-1, q.shape[-2], q.shape[-1])
k = k.reshape(-1, k.shape[-2], k.shape[-1])
v = v.reshape(-1, v.shape[-2], v.shape[-1])
Expand All @@ -73,7 +114,11 @@ def apply(
cur_rank = dist.get_rank(seq_p_group)

# 获取序列长度和文本相关的长度
if img_first:
if split_qkv_input:
img_qkv_len = img_q.shape[0]
txt_qkv_len = txt_q.shape[0]
txt_mask_len = cu_seqlens_qkv[2] - img_qkv_len if img_first and len(cu_seqlens_qkv) == 3 else None
elif img_first:
img_qkv_len = slice_qkv_len
if len(cu_seqlens_qkv) == 3:
txt_qkv_len = cu_seqlens_qkv[1] - slice_qkv_len # 文本查询、键和值的长度
Expand All @@ -87,8 +132,12 @@ def apply(
txt_mask_len = None

# 分别获取 q 和 kv 的头数,支持 GQA(k/v 头数可能少于 q)
_, q_heads, hidden_dims = q.shape
_, kv_heads, _ = k.shape
if split_qkv_input:
_, q_heads, hidden_dims = img_q.shape
_, kv_heads, _ = img_k.shape
else:
_, q_heads, hidden_dims = q.shape
_, kv_heads, _ = k.shape
is_gqa = q_heads != kv_heads
q_shard_heads = q_heads // world_size # q 每个进程处理的头数
kv_shard_heads = kv_heads // world_size # k/v 每个进程处理的头数
Expand Down Expand Up @@ -118,7 +167,14 @@ def apply(
max_seqlen_q = max_seqlen_kv

# 分割图像和文本的查询、键和值
if q_only_img:
if split_qkv_input:
img_q = _contiguous_if_needed(img_q)
img_k = _contiguous_if_needed(img_k)
img_v = _contiguous_if_needed(img_v)
txt_q = _contiguous_if_needed(txt_q)
txt_k = _contiguous_if_needed(txt_k)
txt_v = _contiguous_if_needed(txt_v)
elif q_only_img:
# q 只含图像 token,无需分割;仅 k/v 需要拆出图像和文本部分
img_q = q.contiguous()
txt_q = None
Expand Down Expand Up @@ -462,13 +518,27 @@ def apply(
txt_attn, img_attn = attn[:txt_qkv_len, :], attn[txt_qkv_len:]

# 通信所有进程的图像注意力结果
gathered_txt_attn = None
text_gather_work = None
if async_text_gather:
gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size, reuse_text_gather_buffers, text_gather_buffer_cache_max)
text_gather_work = dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group, async_op=True)

# Finish the async gather before launching any later collective on the same process group.
if async_text_gather:
text_gather_work.wait()

img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, q_shard_heads, hidden_dims, seq_p_group, use_fp8_comm)

# 收集所有进程的文本注意力结果
gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
# Gather text attention synchronously when async launch is disabled.
if not async_text_gather:
gathered_txt_attn = self._get_text_gather_buffers(txt_attn, world_size, reuse_text_gather_buffers, text_gather_buffer_cache_max)
dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
txt_attn = torch.cat(gathered_txt_attn, dim=1) # 合并所有进程的文本注意力结果

if return_split_output:
return img_attn, txt_attn

# 合并图像和文本的注意力结果
if img_first:
attn = torch.cat([img_attn, txt_attn], dim=0)
Expand Down
Loading