|
15 | 15 | ) |
16 | 16 | from flash_attn2.bert_padding import pad_input, unpad_input |
17 | 17 | from flash_attn2.flash_attn_interface import _get_block_size_n |
18 | | -from flash_attn2.layers.rotary import apply_rotary_emb |
| 18 | + |
19 | 19 |
|
20 | 20 | MAX_HEADDIM_SM8x = 192 |
21 | 21 |
|
@@ -1955,295 +1955,9 @@ def test_flash_attn_splitkv( |
1955 | 1955 | assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4 |
1956 | 1956 |
|
1957 | 1957 |
|
1958 | | -# @pytest.mark.parametrize("dtype", ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16])) |
1959 | | -@pytest.mark.parametrize("dtype", [torch.float16]) |
1960 | | -@pytest.mark.parametrize("num_splits", [1, 0]) |
1961 | | -# @pytest.mark.parametrize("num_splits", [1]) |
1962 | | -@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"]) |
1963 | | -# @pytest.mark.parametrize("mha_type", ["mha"]) |
1964 | | -@pytest.mark.parametrize("new_kv", [False, True]) |
1965 | | -# @pytest.mark.parametrize("new_kv", [False]) |
1966 | | -@pytest.mark.parametrize("alibi", [False, True]) |
1967 | | -# @pytest.mark.parametrize("alibi", [False]) |
1968 | | -@pytest.mark.parametrize("local", [False, True]) |
1969 | | -# @pytest.mark.parametrize("local", [False]) |
1970 | | -@pytest.mark.parametrize("causal", [False, True]) |
1971 | | -# @pytest.mark.parametrize("causal", [False]) |
1972 | | -@pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) |
1973 | | -# @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True]) |
1974 | | -@pytest.mark.parametrize("rotary_interleaved", [False, True]) |
1975 | | -# @pytest.mark.parametrize("rotary_interleaved", [False]) |
1976 | | -@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) |
1977 | | -# @pytest.mark.parametrize("rotary_fraction", [0.0]) |
1978 | | -@pytest.mark.parametrize("paged_kv_block_size", [None, 256]) |
1979 | | -# @pytest.mark.parametrize("paged_kv_block_size", [256, 512]) |
1980 | | -# @pytest.mark.parametrize("paged_kv_block_size", [None]) |
1981 | | -@pytest.mark.parametrize("has_leftpad", [False, True]) |
1982 | | -# @pytest.mark.parametrize("has_leftpad", [True]) |
1983 | | -# @pytest.mark.parametrize("has_batch_idx", [False, True]) |
1984 | | -@pytest.mark.parametrize("has_batch_idx", [False]) |
1985 | | -@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256]) |
1986 | | -# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256]) |
1987 | | -# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192]) |
1988 | | -# @pytest.mark.parametrize('d', [56, 80]) |
1989 | | -# @pytest.mark.parametrize("d", [128]) |
1990 | | -@pytest.mark.parametrize( |
1991 | | - "seqlen_q,seqlen_k", |
1992 | | - [ |
1993 | | - (1, 128), |
1994 | | - (1, 339), |
1995 | | - (3, 1024), |
1996 | | - (64, 800), |
1997 | | - (64, 256), |
1998 | | - (3, 799), |
1999 | | - (64, 2048), |
2000 | | - (16, 20000), |
2001 | | - (1, 128 * 1024), |
2002 | | - (16, 128 * 1024), |
2003 | | - (128, 128), |
2004 | | - ], |
2005 | | -) |
2006 | | -# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(256, 128)]) |
2007 | | -def test_flash_attn_kvcache( |
2008 | | - seqlen_q, |
2009 | | - seqlen_k, |
2010 | | - d, |
2011 | | - has_batch_idx, |
2012 | | - has_leftpad, |
2013 | | - paged_kv_block_size, |
2014 | | - rotary_fraction, |
2015 | | - rotary_interleaved, |
2016 | | - seqlen_new_eq_seqlen_q, |
2017 | | - causal, |
2018 | | - local, |
2019 | | - alibi, |
2020 | | - new_kv, |
2021 | | - mha_type, |
2022 | | - num_splits, |
2023 | | - dtype, |
2024 | | - device, |
2025 | | -): |
2026 | | - if device == "cpu": |
2027 | | - pytest.skip("kvcache not supported on CPU") |
2028 | | - if device == "xpu": |
2029 | | - if alibi: |
2030 | | - pytest.skip("alibi not supported on xpu currently") |
2031 | | - if seqlen_q > seqlen_k and new_kv: |
2032 | | - pytest.skip() |
2033 | | - if not new_kv and rotary_fraction > 0.0: |
2034 | | - pytest.skip() |
2035 | | - if has_batch_idx and paged_kv_block_size is not None: |
2036 | | - pytest.skip() |
2037 | | - if has_leftpad and paged_kv_block_size is not None: |
2038 | | - pytest.skip() |
2039 | | - |
2040 | | - # set seed |
2041 | | - torch.random.manual_seed(0) |
2042 | | - batch_size = 2 |
2043 | | - batch_size_cache = batch_size if not has_batch_idx else batch_size * 2 |
2044 | | - nheads = 6 |
2045 | | - # rotary_dim must be a multiple of 16, and must be <= d |
2046 | | - rotary_dim = math.floor(int(rotary_fraction * d) / 16) * 16 |
2047 | | - nheads_k = nheads if mha_type == "mha" else (1 if mha_type == "mqa" else 3) |
2048 | | - assert nheads % nheads_k == 0 |
2049 | | - window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,)) |
2050 | | - q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype) |
2051 | | - seqlen_new = seqlen_q if seqlen_new_eq_seqlen_q else torch.randint(1, seqlen_q + 1, (1,)).item() |
2052 | | - if new_kv: |
2053 | | - k = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) |
2054 | | - v = torch.randn(batch_size, seqlen_new, nheads_k, d, device=device, dtype=dtype) |
2055 | | - else: |
2056 | | - k, v = None, None |
2057 | | - if paged_kv_block_size is None: |
2058 | | - k_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) |
2059 | | - v_cache = torch.randn(batch_size_cache, seqlen_k, nheads_k, d, device=device, dtype=dtype) |
2060 | | - block_table = None |
2061 | | - else: |
2062 | | - ( |
2063 | | - k_cache, |
2064 | | - v_cache, |
2065 | | - block_table, |
2066 | | - k_cache_paged, |
2067 | | - v_cache_paged, |
2068 | | - num_blocks, |
2069 | | - ) = _generate_block_kvcache( |
2070 | | - seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype |
2071 | | - ) |
2072 | | - cache_seqlens = torch.randint( |
2073 | | - 0 if new_kv else 1, |
2074 | | - # If we don't use seqlen_q in the case of causal and rotary, cos/sin won't be long enough |
2075 | | - ( |
2076 | | - (seqlen_k - (seqlen_q if (causal or local) and rotary_dim > 1 else seqlen_new) + 1) |
2077 | | - if new_kv |
2078 | | - else (seqlen_k + 1) |
2079 | | - ), |
2080 | | - (batch_size,), |
2081 | | - dtype=torch.int32, |
2082 | | - device=device, |
2083 | | - ) |
2084 | | - if has_leftpad: |
2085 | | - cache_leftpad = torch.cat([torch.randint(0, cache_seqlens[i].item(), (1,), dtype=torch.int32, device=device) |
2086 | | - if cache_seqlens[i].item() > 0 else torch.zeros(1, dtype=torch.int32, device=device) |
2087 | | - for i in range(batch_size)]) |
2088 | | - else: |
2089 | | - cache_leftpad = None |
2090 | | - arange = rearrange(torch.arange(seqlen_k, device=device), "s -> 1 s") |
2091 | | - cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") |
2092 | | - key_padding_mask = arange < cache_seqlens_expanded + (seqlen_new if new_kv else 0) |
2093 | | - if has_leftpad: |
2094 | | - key_padding_mask = torch.logical_and( |
2095 | | - key_padding_mask, arange >= cache_leftpad.unsqueeze(-1).expand(-1, seqlen_k) |
2096 | | - ) |
2097 | | - if has_batch_idx: |
2098 | | - cache_batch_idx = torch.randperm(batch_size_cache, dtype=torch.int32, device=device)[ |
2099 | | - :batch_size |
2100 | | - ] |
2101 | | - else: |
2102 | | - cache_batch_idx = None |
2103 | | - if alibi: |
2104 | | - alibi_slopes = torch.rand(batch_size, nheads, device=device, dtype=torch.float32) * 0.3 |
2105 | | - attn_bias = attn_bias_from_alibi_slopes( |
2106 | | - alibi_slopes, seqlen_q, seqlen_k, None, key_padding_mask, causal=causal, key_leftpad=cache_leftpad |
2107 | | - ) |
2108 | | - else: |
2109 | | - alibi_slopes, attn_bias = None, None |
2110 | | - # cache_seqlens = torch.tensor([64], dtype=torch.int32, device=device) |
2111 | | - if rotary_dim > 0: |
2112 | | - angle = ( |
2113 | | - torch.rand( |
2114 | | - seqlen_k if paged_kv_block_size is None else num_blocks * paged_kv_block_size, |
2115 | | - rotary_dim // 2, |
2116 | | - device=device, |
2117 | | - ) |
2118 | | - * 2 |
2119 | | - * math.pi |
2120 | | - ) |
2121 | | - cos = torch.cos(angle).to(dtype=dtype) |
2122 | | - sin = torch.sin(angle).to(dtype=dtype) |
2123 | | - if causal or local: |
2124 | | - q_ro = apply_rotary_emb( |
2125 | | - q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved |
2126 | | - ) |
2127 | | - else: |
2128 | | - q_ro = rearrange( |
2129 | | - apply_rotary_emb( |
2130 | | - rearrange(q, "b s h d -> b 1 (s h) d"), |
2131 | | - cos, |
2132 | | - sin, |
2133 | | - seqlen_offsets=cache_seqlens, |
2134 | | - interleaved=rotary_interleaved, |
2135 | | - ), |
2136 | | - "b 1 (s h) d -> b s h d", |
2137 | | - s=seqlen_q, |
2138 | | - ) |
2139 | | - # q_ro = q |
2140 | | - k_ro = apply_rotary_emb( |
2141 | | - k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved |
2142 | | - ) |
2143 | | - else: |
2144 | | - cos, sin = None, None |
2145 | | - q_ro, k_ro = q, k |
2146 | | - # k_cache[:, 64:] = -1 |
2147 | | - k_cache_ref = ( |
2148 | | - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] |
2149 | | - ).clone() |
2150 | | - v_cache_ref = ( |
2151 | | - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] |
2152 | | - ).clone() |
2153 | | - if new_kv: |
2154 | | - update_mask = torch.logical_and( |
2155 | | - cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new |
2156 | | - ) |
2157 | | - k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") |
2158 | | - v_cache_ref[update_mask] = rearrange(v, "b s ... -> (b s) ...") |
2159 | | - k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
2160 | | - v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=nheads // nheads_k) |
2161 | | - out = flash_attn_with_kvcache( |
2162 | | - q, |
2163 | | - k_cache if paged_kv_block_size is None else k_cache_paged, |
2164 | | - v_cache if paged_kv_block_size is None else v_cache_paged, |
2165 | | - k, |
2166 | | - v, |
2167 | | - rotary_cos=cos, |
2168 | | - rotary_sin=sin, |
2169 | | - cache_seqlens=cache_seqlens, |
2170 | | - cache_batch_idx=cache_batch_idx, |
2171 | | - cache_leftpad=cache_leftpad, |
2172 | | - block_table=block_table, |
2173 | | - causal=causal, |
2174 | | - window_size=window_size, |
2175 | | - rotary_interleaved=rotary_interleaved, |
2176 | | - alibi_slopes=alibi_slopes, |
2177 | | - num_splits=num_splits, |
2178 | | - ) |
2179 | | - # out = flash_attn_with_kvcache( |
2180 | | - # q, k_cache, v_cache, cache_seqlens=cache_seqlens, causal=causal, window_size=window_size |
2181 | | - # ) |
2182 | | - # out = flash_attn_with_kvcache(q, k_cache, v_cache, causal=causal, window_size=window_size) |
2183 | | - # qk = torch.einsum("bqhd,bkhd->bhqk", q, k_cache_ref) |
2184 | | - # m = qk.amax(-1, keepdim=True) |
2185 | | - # s_tmp = torch.exp((qk - m) / math.sqrt(d)) |
2186 | | - # o1 = torch.einsum('bhst,bthd->bshd', s_tmp, v_cache_ref) |
2187 | | - # lse_ref = torch.logsumexp(qk / math.sqrt(d), -1) |
2188 | | - # probs = torch.softmax(qk, dim=-1) |
2189 | | - out_ref, _ = attention_ref( |
2190 | | - q_ro, |
2191 | | - k_cache_rep, |
2192 | | - v_cache_rep, |
2193 | | - None, |
2194 | | - key_padding_mask, |
2195 | | - attn_bias, |
2196 | | - 0.0, |
2197 | | - None, |
2198 | | - causal=causal, |
2199 | | - window_size=window_size, |
2200 | | - key_leftpad=cache_leftpad, |
2201 | | - ) |
2202 | | - out_pt, _ = attention_ref( |
2203 | | - q_ro, |
2204 | | - k_cache_rep, |
2205 | | - v_cache_rep, |
2206 | | - None, |
2207 | | - key_padding_mask, |
2208 | | - attn_bias, |
2209 | | - 0.0, |
2210 | | - None, |
2211 | | - causal=causal, |
2212 | | - window_size=window_size, |
2213 | | - upcast=False, |
2214 | | - reorder_ops=True, |
2215 | | - key_leftpad=cache_leftpad, |
2216 | | - ) |
2217 | | - print(f"Output max diff: {(out - out_ref).abs().max().item()}") |
2218 | | - print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") |
2219 | | - print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}") |
2220 | | - print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}") |
2221 | 1958 |
|
2222 | | - # Check that FlashAttention's numerical error is at most `mult` times the numerical error |
2223 | | - # of a PyTorch implementation, plus a small absolute tolerance. |
2224 | | - if new_kv: |
2225 | | - if paged_kv_block_size is None: |
2226 | | - k_cache_select = ( |
2227 | | - k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)] |
2228 | | - ) |
2229 | | - v_cache_select = ( |
2230 | | - v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)] |
2231 | | - ) |
2232 | | - else: |
2233 | | - k_cache_select = rearrange( |
2234 | | - k_cache_paged[block_table.to(dtype=torch.long).flatten()], |
2235 | | - "(b nblocks) block_size ... -> b (nblocks block_size) ...", |
2236 | | - b=batch_size, |
2237 | | - )[:, :seqlen_k] |
2238 | | - v_cache_select = rearrange( |
2239 | | - v_cache_paged[block_table.to(dtype=torch.long).flatten()], |
2240 | | - "(b nblocks) block_size ... -> b (nblocks block_size) ...", |
2241 | | - b=batch_size, |
2242 | | - )[:, :seqlen_k] |
2243 | | - assert torch.allclose(k_cache_select, k_cache_ref, rtol=1e-3, atol=1e-3) |
2244 | | - assert torch.equal(v_cache_select, v_cache_ref) |
2245 | | - mult = 3 if not alibi else 5 |
2246 | | - assert (out - out_ref).abs().max().item() <= mult * (out_pt - out_ref).abs().max().item() + 1e-5 |
| 1959 | +# NOTE: test_flash_attn_kvcache was removed because it relied on apply_rotary_emb from |
| 1960 | +# flash_attn2.layers.rotary, an upstream module that is not shipped with this kernel. |
2247 | 1961 |
|
2248 | 1962 |
|
2249 | 1963 | def _generate_block_kvcache(seqlen_k, paged_kv_block_size, batch_size, nheads_k, d, device, dtype): |
|
0 commit comments