diff --git a/lightx2v/common/ops/attn/flash_attn.py b/lightx2v/common/ops/attn/flash_attn.py
index 1c189bc9b..8a1c19bab 100755
--- a/lightx2v/common/ops/attn/flash_attn.py
+++ b/lightx2v/common/ops/attn/flash_attn.py
@@ -5,8 +5,8 @@
 from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim
 
 try:
-    from flash_attn import flash_attn_func_v2
-    from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2
+    from flash_attn import flash_attn_func as flash_attn_func_v2
+    from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2
 except ImportError:
     logger.info("flash_attn2 not found, please install flash_attn2 first")
     flash_attn_func_v2 = None
@@ -59,16 +59,12 @@ def apply(
             q = q.unsqueeze(0)
             k = k.unsqueeze(0)
             v = v.unsqueeze(0)
-            x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1)
+            x = flash_attn_func_v2(q, k, v).reshape(total_seqlen, -1)
         else:
             if cu_seqlens_q.is_cpu:
                 cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
             if cu_seqlens_kv.is_cpu:
                 cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
-            if max_seqlen_q.is_cpu:
-                max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-            if max_seqlen_kv.is_cpu:
-                max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
             if len(q.shape) == 4:
                 q = q.reshape(-1, q.shape[-2], q.shape[-1])
                 k = k.reshape(-1, k.shape[-2], k.shape[-1])
@@ -113,16 +109,12 @@ def apply(
             q = q.unsqueeze(0)
             k = k.unsqueeze(0)
             v = v.unsqueeze(0)
-            x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1)
+            x = flash_attn_func_v3(q, k, v).reshape(total_seqlen, -1)
         else:
             if cu_seqlens_q.is_cpu:
                 cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True)
             if cu_seqlens_kv.is_cpu:
                 cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True)
-            if max_seqlen_q.is_cpu:
-                max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True)
-            if max_seqlen_kv.is_cpu:
-                max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True)
             if len(q.shape) == 4:
                 q = q.reshape(-1, q.shape[-2], q.shape[-1])
                 k = k.reshape(-1, k.shape[-2], k.shape[-1])
diff --git a/lightx2v/common/ops/attn/utils/sla_util.py b/lightx2v/common/ops/attn/utils/sla_util.py
index 9177f9a68..2ed4b0c8b 100755
--- a/lightx2v/common/ops/attn/utils/sla_util.py
+++ b/lightx2v/common/ops/attn/utils/sla_util.py
@@ -42,6 +42,15 @@ def get_block_map(q, k, topk_ratio, BLKQ=64, BLKK=64):
     arg_k = k - torch.mean(k, dim=-2, keepdim=True)  # smooth-k technique in SageAttention
     pooled_qblocks = mean_pool(q, BLKQ)
     pooled_kblocks = mean_pool(arg_k, BLKK)
+
+    # GQA
+    num_q_heads = q.size(1)
+    num_kv_heads = k.size(1)
+    if num_q_heads != num_kv_heads:
+        assert num_q_heads % num_kv_heads == 0, f"Number of Q heads ({num_q_heads}) must be divisible by number of KV heads ({num_kv_heads})"
+        repeat_factor = num_q_heads // num_kv_heads
+        pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1)
+
     pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)
 
     K = pooled_score.shape[-1]
diff --git a/lightx2v/common/ops/attn/utils/sparge_util.py b/lightx2v/common/ops/attn/utils/sparge_util.py
index f7c25cf61..9aa2e777b 100644
--- a/lightx2v/common/ops/attn/utils/sparge_util.py
+++ b/lightx2v/common/ops/attn/utils/sparge_util.py
@@ -213,6 +213,15 @@ def get_block_map_meansim(q, k, is_causal=False, BLKQ=128, BLKK=64, simthreshd1=
     pooled_qblocks, sim_qblocks = get_pool_sim_triton_simmean(q, BLKQ, simthreshd1)
     pooled_kblocks, sim_kblocks = get_pool_sim_triton_simmean(k, BLKK, simthreshd1)
 
+    # GQA
+    num_q_heads = q.size(1)
+    num_kv_heads = k.size(1)
+    if num_q_heads != num_kv_heads:
+        assert num_q_heads % num_kv_heads == 0, f"Number of Q heads ({num_q_heads}) must be divisible by number of KV heads ({num_kv_heads})"
+        repeat_factor = num_q_heads // num_kv_heads
+        pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1)
+        sim_kblocks = sim_kblocks.repeat_interleave(repeat_factor, dim=1)
+
     sim_kblocks = sim_kblocks.unsqueeze(-2).expand(-1, -1, nq, -1)  # faster than repeat
     sim_qblocks = sim_qblocks.unsqueeze(-1).expand(-1, -1, -1, nk)
     pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2) * q.shape[-1] ** -0.5
diff --git a/lightx2v/models/networks/bagel/infer/transformer_infer.py b/lightx2v/models/networks/bagel/infer/transformer_infer.py
index 9151d2112..dfd954d1e 100644
--- a/lightx2v/models/networks/bagel/infer/transformer_infer.py
+++ b/lightx2v/models/networks/bagel/infer/transformer_infer.py
@@ -160,8 +160,8 @@ def self_attn(
         merged_value_states = packed_value_states
         key_values_lens = query_lens
 
-        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
-        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0))
+        cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0)).to(AI_DEVICE)
+        cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(key_values_lens, dim=0), (1, 0)).to(AI_DEVICE)
 
         packed_attn_output = flash_attn_varlen_func(
             q=packed_query_states,
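A few hedged sketches of the behaviors this patch touches follow; all shapes, head counts, and lengths are illustrative and not taken from the repository.

First, the reshape change in both flash-attn dense paths: the output is flattened back to the packed token count (`total_seqlen`) rather than `bs * max_seqlen_q`, which stays shape-correct even when `max_seqlen_q` is a tensor or sequences have uneven lengths. A toy shape walk-through; the attention output is stood in for by `q` itself, since the real kernel needs a CUDA build of flash_attn:

```python
import torch

# Hypothetical packed layout used by the dense (non-varlen) path:
# q arrives as (total_seqlen, num_heads, head_dim) and gets a batch dim of 1.
total_seqlen, num_heads, head_dim = 384, 8, 64
q = torch.randn(total_seqlen, num_heads, head_dim)

q = q.unsqueeze(0)                 # (1, total_seqlen, num_heads, head_dim)
out = q                            # stand-in for flash_attn_func_v2(q, k, v)
x = out.reshape(total_seqlen, -1)  # (total_seqlen, num_heads * head_dim)
print(x.shape)                     # torch.Size([384, 512])
```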
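Second, the GQA branch added to `get_block_map` and `get_block_map_meansim`: each KV head's pooled block statistics are duplicated so they line up with the query heads that share it before the block-score matmul. A minimal sketch of that behavior with made-up sizes:

```python
import torch

# Hypothetical shapes for illustration only.
bs, num_q_heads, num_kv_heads = 1, 8, 2
num_q_blocks, num_k_blocks, head_dim = 32, 16, 64

pooled_qblocks = torch.randn(bs, num_q_heads, num_q_blocks, head_dim)
pooled_kblocks = torch.randn(bs, num_kv_heads, num_k_blocks, head_dim)

# Same logic as the patch: each KV head's pooled blocks are repeated for the
# query heads in its group so the score matmul sees matching head counts.
if num_q_heads != num_kv_heads:
    assert num_q_heads % num_kv_heads == 0
    repeat_factor = num_q_heads // num_kv_heads
    pooled_kblocks = pooled_kblocks.repeat_interleave(repeat_factor, dim=1)

pooled_score = pooled_qblocks @ pooled_kblocks.transpose(-1, -2)
print(pooled_score.shape)  # torch.Size([1, 8, 32, 16])
```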
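Finally, the `transformer_infer.py` change only moves the `cu_seqlens_*` offsets onto the accelerator device (`AI_DEVICE` is whatever device constant the surrounding file already uses). A small sketch of the pad-after-cumsum pattern that builds those offsets, with made-up lengths:

```python
import torch

# Hypothetical per-sample sequence lengths.
query_lens = torch.tensor([128, 256, 64])

# Prefix sums padded with a leading zero give the varlen offsets
# [0, 128, 384, 448] that flash_attn_varlen_func consumes.
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(query_lens, dim=0), (1, 0))
print(cu_seqlens_q)  # tensor([  0, 128, 384, 448])

# The patch appends .to(AI_DEVICE) so these offsets live on the same device
# as q/k/v before the varlen attention call; here we just use CUDA if present.
device = "cuda" if torch.cuda.is_available() else "cpu"
cu_seqlens_q = cu_seqlens_q.to(device)
```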