|
5 | 5 | from .utils.sparge_util import block_map_ordinal_lut_triton, get_block_map_meansim |
6 | 6 |
|
7 | 7 | try: |
8 | | - from flash_attn import flash_attn_func_v2 |
9 | | - from flash_attn.flash_attn_interface import flash_attn_varlen_func_v2 |
| 8 | + from flash_attn import flash_attn_func as flash_attn_func_v2 |
| 9 | + from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v2 |
10 | 10 | except ImportError: |
11 | 11 | logger.info("flash_attn2 not found, please install flash_attn2 first") |
12 | 12 | flash_attn_func_v2 = None |
@@ -59,16 +59,12 @@ def apply( |
59 | 59 | q = q.unsqueeze(0) |
60 | 60 | k = k.unsqueeze(0) |
61 | 61 | v = v.unsqueeze(0) |
62 | | - x = flash_attn_func_v2(q, k, v).reshape(bs * max_seqlen_q, -1) |
| 62 | + x = flash_attn_func_v2(q, k, v).reshape(total_seqlen, -1) |
63 | 63 | else: |
64 | 64 | if cu_seqlens_q.is_cpu: |
65 | 65 | cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) |
66 | 66 | if cu_seqlens_kv.is_cpu: |
67 | 67 | cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) |
68 | | - if max_seqlen_q.is_cpu: |
69 | | - max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) |
70 | | - if max_seqlen_kv.is_cpu: |
71 | | - max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) |
72 | 68 | if len(q.shape) == 4: |
73 | 69 | q = q.reshape(-1, q.shape[-2], q.shape[-1]) |
74 | 70 | k = k.reshape(-1, k.shape[-2], k.shape[-1]) |
@@ -113,16 +109,12 @@ def apply( |
113 | 109 | q = q.unsqueeze(0) |
114 | 110 | k = k.unsqueeze(0) |
115 | 111 | v = v.unsqueeze(0) |
116 | | - x = flash_attn_func_v3(q, k, v).reshape(bs * max_seqlen_q, -1) |
| 112 | + x = flash_attn_func_v3(q, k, v).reshape(total_seqlen, -1) |
117 | 113 | else: |
118 | 114 | if cu_seqlens_q.is_cpu: |
119 | 115 | cu_seqlens_q = cu_seqlens_q.to(q.device, non_blocking=True) |
120 | 116 | if cu_seqlens_kv.is_cpu: |
121 | 117 | cu_seqlens_kv = cu_seqlens_kv.to(k.device, non_blocking=True) |
122 | | - if max_seqlen_q.is_cpu: |
123 | | - max_seqlen_q = max_seqlen_q.to(q.device, non_blocking=True) |
124 | | - if max_seqlen_kv.is_cpu: |
125 | | - max_seqlen_kv = max_seqlen_kv.to(k.device, non_blocking=True) |
126 | 118 | if len(q.shape) == 4: |
127 | 119 | q = q.reshape(-1, q.shape[-2], q.shape[-1]) |
128 | 120 | k = k.reshape(-1, k.shape[-2], k.shape[-1]) |
|
0 commit comments