|
21 | 21 | import jax |
22 | 22 | from jax.ad_checkpoint import checkpoint_name |
23 | 23 | from jax.experimental import layout |
| 24 | +from jax.sharding import PartitionSpec as P |
| 25 | +from jax.experimental import shard_map |
24 | 26 | import jax.numpy as jnp |
25 | 27 | from jax.sharding import Mesh, NamedSharding |
26 | 28 |
|
@@ -623,7 +625,11 @@ def __init__( |
623 | 625 | ) |
624 | 626 |
|
625 | 627 | # Module attribute names must match names previously passed to Linen for checkpointing |
626 | | - self.MlaKVCache_0 = self.init_mla_kv_caches(inputs_kv_shape) if model_mode != MODEL_MODE_TRAIN else None |
| 628 | + self.MlaKVCache_0 = ( |
| 629 | + self.init_mla_kv_caches(inputs_kv_shape) |
| 630 | + if model_mode != MODEL_MODE_TRAIN and config.attention != "vllm_rpa" |
| 631 | + else None |
| 632 | + ) |
627 | 633 |
|
628 | 634 | def _init_projections(self, inputs_q_shape: Tuple, inputs_kv_shape: Tuple) -> None: |
629 | 635 | """Initializes the MLA-specific projections.""" |
@@ -941,15 +947,118 @@ def mla_kv_projection(self, inputs: Array, inputs_positions: Array, decoder_segm |
941 | 947 |
|
942 | 948 | key, value = self.mla_get_key_value(low_rank_main, key_rope, model_mode) |
943 | 949 | cached_values = [None, None] |
944 | | - if self.config.attention != "paged" and model_mode != MODEL_MODE_TRAIN: |
| 950 | + if self.config.attention != "paged" and self.config.attention != "vllm_rpa" and model_mode != MODEL_MODE_TRAIN: |
945 | 951 | if self.config.mla_naive_kvcache: |
946 | 952 | cached_values = self.update_kv_caches(key, value, decoder_segment_ids, model_mode, previous_chunk) |
947 | 953 | else: |
948 | 954 | cached_values = self.update_mla_kv_caches( |
949 | 955 | low_rank_main, key_rope, decoder_segment_ids, model_mode, previous_chunk |
950 | 956 | ) |
951 | 957 |
|
952 | | - return key, value, cached_values |
| 958 | + return key, value, cached_values, low_rank_main, key_rope |
| 959 | + |
  def mla_rpa_vllm(self, q_nope, q_rope, k_latent, k_rope, mla_kv_cache, mla_metadata):
    """Runs MLA ragged paged attention via the vLLM (tpu_inference) TPU kernel.

    Uses the MLA "weight absorption" formulation: instead of expanding the latent
    KV into full per-head keys/values, the key-side half of `self.wkv_b` is
    absorbed into the query (see `q_absorbed` below) so the kernel attends
    directly in the low-rank latent space, and the value-side half is applied to
    the kernel output afterwards.

    Args:
      q_nope: Query nope part [T, N, qk_nope_head_dim]
      q_rope: Query rope part [T, N, qk_rope_head_dim]
      k_latent: Latent KV representation [S, kv_lora_rank] (NOT expanded k_nope)
      k_rope: Key rope part [S, qk_rope_head_dim] (NO head dimension)
      mla_kv_cache: The KV cache updated in place by the kernel and returned.
      mla_metadata: Attention metadata; must expose `seq_lens`, `block_tables`,
        `query_start_loc`, and `request_distribution` (presumably vLLM's ragged
        batch descriptors — verify against the tpu_inference caller).

    Returns:
      A `(kv_cache, output)` tuple: the kernel-updated KV cache and the
      attention output of shape [T, N, v_head_dim] (after the wv_b projection).

    Raises:
      ImportError: if the `tpu_inference` kernels (vllm-tpu package) are absent.
      ValueError: if either `mla_kv_cache` or `mla_metadata` is None.
    """
    md = mla_metadata
    try:
      # Deferred import: these kernels only exist when the optional vllm-tpu
      # package is installed, so importing at module top level would break
      # non-vLLM configurations.
      # pylint: disable=import-outside-toplevel
      # pytype: disable=import-error
      from tpu_inference.kernels.mla.v1.kernel import mla_ragged_paged_attention
      from tpu_inference.kernels.ragged_paged_attention.v3.tuned_block_sizes import get_tuned_block_sizes
    except ImportError as e:
      raise ImportError(
          "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
      ) from e

    if mla_kv_cache is None or mla_metadata is None:
      raise ValueError("kv_cache and attention_metadata must be provided when using vLLM.")

    # Split the shared up-projection weight into its key-nope and value halves
    # along the last axis: [..., :qk_nope_head_dim] maps latent -> k_nope,
    # [..., qk_nope_head_dim:] maps latent -> v.
    wkv_b_kernel = self.wkv_b.kernel.value
    wk_b_kernel = wkv_b_kernel[..., : self.qk_nope_head_dim]
    wv_b_kernel = wkv_b_kernel[..., self.qk_nope_head_dim :]
    # Absorb W_UK into the query: [T, N, H] x [A, N, H] -> [T, N, A], so the
    # dot products inside the kernel happen in the kv_lora_rank (A) space.
    q_absorbed = jnp.einsum("TNH,ANH->TNA", q_nope, wk_b_kernel)

    def _mla_ragged_paged_attention(q, q_rope, k, k_rope, kv_cache, *args):
      # Per-shard body executed under shard_map; *args carries the replicated
      # metadata tensors (seq_lens, block_tables, query_start_loc, distribution).
      def _initialize_block_sizes():
        # Set reasonable starting estimates for block sizes. (TODO(gpolovets): update this to use tuned sizes)
        max_num_tokens = q_absorbed.shape[0]
        max_num_seqs = md.seq_lens.shape[0]
        num_page_indices = md.block_tables.shape[0]
        # Block tables are flat, so they must divide evenly into per-sequence runs.
        assert num_page_indices % max_num_seqs == 0
        pages_per_seq = num_page_indices // max_num_seqs
        # num_kv_pages_per_block = min(pages_per_seq, 16)
        bkv_p, bq_sz = get_tuned_block_sizes(
            q_nope.dtype,
            q_nope.dtype,  # changed to q_nope dtype from mla_kv_cache.dtype
            self.num_query_heads,
            1,  # num_kv_heads for MLA kernel
            self.qk_nope_head_dim,
            q_nope.shape[1],  # page size ?? kv_cache.shape[1]
            max_num_tokens,
            pages_per_seq,
        )
        # Clamp to small blocks: the tuned sizes are for the non-MLA kernel and
        # larger values were observed to exceed TPU memory here.
        num_kv_pages_per_block = min(pages_per_seq, bkv_p, 4)
        num_queries_per_block = min(max_num_tokens, bq_sz, 4)  # OOMS at 8
        return num_kv_pages_per_block, num_queries_per_block

      num_kv_pages_per_block, num_queries_per_block = _initialize_block_sizes()
      # NOTE(review): sm_scale=1.0 assumes the query was pre-scaled upstream —
      # confirm against the query projection.
      output, kv_cache = mla_ragged_paged_attention(
          q,
          q_rope,
          k,
          k_rope,
          kv_cache,
          *args,
          sm_scale=1.0,
          num_kv_pages_per_block=num_kv_pages_per_block,
          num_queries_per_block=num_queries_per_block,
      )
      return kv_cache, output

    # Shard the head dimension of q/q_rope and the leading dimension of the
    # latent k tensors and cache across the fused attention axes; metadata is
    # replicated across ("data", "attn_dp").
    in_specs = (
        P(("attn_dp", "model", "expert"), None, None),  # q
        P(("attn_dp", "model", "expert"), None, None),  # q_rope
        P(("attn_dp", "model", "expert"), None),  # k
        P(("attn_dp", "model", "expert"), None),  # k_rope
        P(("attn_dp", "model", "expert")),  # kv_cache
        P(("data", "attn_dp")),  # md.seq_lens: Replicated
        P(("data", "attn_dp")),  # page_indices_flat: Replicated
        P(("data", "attn_dp")),  # query_start_loc: Replicated
        P(("data", "attn_dp")),  # distribution: Replicated
    )

    out_specs = (P(("attn_dp", "model", "expert"), None, None), P(("attn_dp", "model", "expert")))

    # NOTE(review): jax.jit is applied to a freshly-created shard_map callable on
    # every invocation, so the jit cache keys on a new function object each call;
    # if this method is itself called outside a traced context it will re-trace
    # every time — consider hoisting/caching the jitted function.
    kv_cache, output = jax.jit(
        shard_map.shard_map(
            _mla_ragged_paged_attention,
            mesh=self.mesh,
            in_specs=in_specs,
            out_specs=out_specs,
            check_rep=False,  # kernel updates the cache in place; replication check unsupported
        ),
    )(
        q_absorbed,
        q_rope,
        k_latent,
        k_rope,
        mla_kv_cache,
        md.seq_lens,
        md.block_tables,
        md.query_start_loc,
        md.request_distribution,
    )
    # Apply the value half of the absorbed projection: latent-space output
    # [T, N, A] x [A, N, H] -> per-head values [T, N, v_head_dim].
    output = jnp.einsum("TNA,ANH->TNH", output, wv_b_kernel)
    return kv_cache, output
953 | 1062 |
|
954 | 1063 | def __call__( |
955 | 1064 | self, |
@@ -1005,7 +1114,7 @@ def __call__( |
1005 | 1114 | query, low_rank_q = self.mla_query_projection(inputs_q, inputs_positions, model_mode) |
1006 | 1115 | if self.config.force_q_layout: |
1007 | 1116 | query = layout.with_layout_constraint(query, DLL(major_to_minor=(0, 2, 3, 1))) |
1008 | | - key, value, cached_values = self.mla_kv_projection( |
| 1117 | + key, value, cached_values, low_rank_main, key_rope = self.mla_kv_projection( |
1009 | 1118 | inputs_kv, inputs_positions, decoder_segment_ids, model_mode, previous_chunk |
1010 | 1119 | ) |
1011 | 1120 | query = checkpoint_name(query, "query_proj") |
@@ -1039,7 +1148,22 @@ def __call__( |
1039 | 1148 | ) |
1040 | 1149 | unnormalized_out = unnormalized_out[..., : self.v_head_dim] |
1041 | 1150 | out = unnormalized_out / (exp_sum + 1e-9) if exp_sum is not None else unnormalized_out |
| 1151 | + elif self.config.attention == "vllm_rpa" and model_mode != MODEL_MODE_TRAIN and kv_cache is not None: |
| 1152 | + batch, seq_len, num_heads, _ = query.shape |
| 1153 | + query = query.reshape(-1, query.shape[2], query.shape[3]) |
| 1154 | + q_nope, q_rope = jnp.split(query, [self.qk_nope_head_dim], axis=-1) |
| 1155 | + |
| 1156 | + k_latent = low_rank_main.reshape(-1, self.kv_lora_rank) |
| 1157 | + k_rope_squeezed = key_rope.reshape(-1, self.qk_rope_head_dim) |
| 1158 | + |
| 1159 | + updated_kv, attn_out = self.mla_rpa_vllm( |
| 1160 | + q_nope, q_rope, k_latent, k_rope_squeezed, mla_kv_cache=kv_cache, mla_metadata=attention_metadata |
| 1161 | + ) |
| 1162 | + out = attn_out.reshape(batch, seq_len, num_heads, self.v_head_dim) |
| 1163 | + kv_cache = updated_kv |
1042 | 1164 | else: |
| 1165 | + if self.config.attention == "vllm_rpa" and kv_cache is None and model_mode != MODEL_MODE_TRAIN: |
| 1166 | + model_mode = MODEL_MODE_TRAIN |
1043 | 1167 | out = self.attention_op( |
1044 | 1168 | query, |
1045 | 1169 | key, |
|
0 commit comments