Skip to content

Commit 0684d35

Browse files
Support sparse GQA and fix FlashAttention (#1039)
1 parent 0116e5d commit 0684d35

4 files changed

Lines changed: 24 additions & 14 deletions

File tree

lightx2v/common/ops/attn/flash_attn.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55
from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim
66

77
try:
8-
from flash_attn import flash_attn_func_v2
9-
from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
8+
from flash_attn import flash_attn_func as flash_attn_func_v2
9+
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2
1010
except ImportError:
1111
logger.info("flash_attn2 not found, please install flash_attn2 first")
1212
flash_attn_func_v2 = None
@@ -59,16 +59,12 @@ def apply(
5959
q = q.unsqueeze(0)
6060
k = k.unsqueeze(0)
6161
v = v.unsqueeze(0)
62-
x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1)
62+
x = flash_attn_func_v2(q, k, v).reshape(total_seqlen, -1)
6363
else:
6464
if cu_seqlens_q.is_cpu:
6565
cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
6666
if cu_seqlens_kv.is_cpu:
6767
cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
68-
if max_seqlen_q.is_cpu:
69-
max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
70-
if max_seqlen_kv.is_cpu:
71-
max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
7268
if len(q.shape) == 4:
7369
q = q.reshape(-1, q.shape[-2], q.shape[-1])
7470
k = k.reshape(-1, k.shape[-2], k.shape[-1])
@@ -113,16 +109,12 @@ def apply(
113109
q = q.unsqueeze(0)
114110
k = k.unsqueeze(0)
115111
v = v.unsqueeze(0)
116-
x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1)
112+
x = flash_attn_func_v3(q, k, v).reshape(total_seqlen, -1)
117113
else:
118114
if cu_seqlens_q.is_cpu:
119115
cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
120116
if cu_seqlens_kv.is_cpu:
121117
cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
122-
if max_seqlen_q.is_cpu:
123-
max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
124-
if max_seqlen_kv.is_cpu:
125-
max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
126118
if len(q.shape) == 4:
127119
q = q.reshape(-1, q.shape[-2], q.shape[-1])
128120
k = k.reshape(-1, k.shape[-2], k.shape[-1])

lightx2v/common/ops/attn/utils/sla_util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,15 @@ def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):
4242
arg_k = k - torch.mean(k, dim=-2, keepdim=True) # smooth-k technique in SageAttention
4343
pooled_qblocks = mean_pool(q, BLKQ)
4444
pooled_kblocks = mean_pool(arg_k, BLKK)
45+
46+
# GQA
47+
num_q_heads = q.size(1)
48+
num_kv_heads = k.size(1)
49+
if num_q_heads != num_kv_heads:
50+
assert num_q_heads % num_kv_heads == 0, f"Number of Q heads ({num_q_heads}) must be divisible by number of KV heads ({num_kv_heads})"
51+
repeat_factor = num_q_heads // num_kv_heads
52+
pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1)
53+
4554
pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)
4655

4756
K = pooled_score.shape[-1]

lightx2v/common/ops/attn/utils/sparge_util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,15 @@ def get_block_map_meansim(q, k, is_causal=False, BLKQ=128, BLKK=64, simthreshd1=
213213
pooled_qblocks, sim_qblocks = get_pool_sim_triton_simmean(q, BLKQ, simthreshd1)
214214
pooled_kblocks, sim_kblocks = get_pool_sim_triton_simmean(k, BLKK, simthreshd1)
215215

216+
# GQA
217+
num_q_heads = q.size(1)
218+
num_kv_heads = k.size(1)
219+
if num_q_heads != num_kv_heads:
220+
assert num_q_heads % num_kv_heads == 0, f"Number of Q heads ({num_q_heads}) must be divisible by number of KV heads ({num_kv_heads})"
221+
repeat_factor = num_q_heads // num_kv_heads
222+
pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1)
223+
sim_kblocks = sim_kblocks.repeat_interleave(repeat_factor, dim=1)
224+
216225
sim_kblocks = sim_kblocks.unsqueeze(-2).expand(-1, -1, nq, -1) # faster than repeat
217226
sim_qblocks = sim_qblocks.unsqueeze(-1).expand(-1, -1, -1, nk)
218227
pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) * q.shape[-1] ** -0.5

lightx2v/models/networks/bagel/infer/transformer_infer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,8 @@ def self_attn(
160160
merged_value_states = packed_value_states
161161
key_values_lens = query_lens
162162

163-
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
164-
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
163+
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
164+
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)
165165

166166
packed_attn_output = flash_attn_varlen_func(
167167
q=packed_query_states,

Comments (0)