|
34 | 34 | AxisNames, |
35 | 35 | AxisIdxes, |
36 | 36 | ATTN_LENGTH, |
37 | | - ATTN_LENGTH_NO_EXP, |
38 | 37 | DType, |
39 | 38 | Config, |
40 | 39 | Array, |
|
44 | 43 | KV_HEAD, |
45 | 44 | KV_HEAD_DIM, |
46 | 45 | KV_BATCH, |
47 | | - KV_BATCH_NO_EXP, |
48 | 46 | ATTN_EMBED, |
49 | 47 | MODEL_MODE_AUTOREGRESSIVE, |
50 | 48 | MODEL_MODE_TRAIN, |
51 | 49 | MODEL_MODE_PREFILL, |
52 | | - EP_AS_CONTEXT, |
53 | 50 | AttentionType, |
54 | 51 | ) |
55 | 52 | from maxtext.layers import nnx_wrappers |
@@ -141,14 +138,11 @@ def attention_as_linen( |
141 | 138 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
142 | 139 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
143 | 140 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
144 | | - query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
145 | | - key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
146 | | - value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
147 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
148 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
149 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
150 | | - input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), |
151 | | - out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), |
| 141 | + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 142 | + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 143 | + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 144 | + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED), |
| 145 | + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV), |
152 | 146 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), |
153 | 147 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), |
154 | 148 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -208,9 +202,6 @@ def attention_as_linen( |
208 | 202 | query_axis_names=query_axis_names, |
209 | 203 | key_axis_names=key_axis_names, |
210 | 204 | value_axis_names=value_axis_names, |
211 | | - ep_query_axis_names=ep_query_axis_names, |
212 | | - ep_key_axis_names=ep_key_axis_names, |
213 | | - ep_value_axis_names=ep_value_axis_names, |
214 | 205 | input_axis_names=input_axis_names, |
215 | 206 | out_axis_names=out_axis_names, |
216 | 207 | prefill_input_axis_names=prefill_input_axis_names, |
@@ -304,14 +295,11 @@ def __init__( |
304 | 295 | prefill_query_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
305 | 296 | prefill_key_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
306 | 297 | prefill_value_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, KV_HEAD, KV_HEAD_DIM), |
307 | | - query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
308 | | - key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
309 | | - value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH_NO_EXP, KV_HEAD, KV_HEAD_DIM), |
310 | | - ep_query_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
311 | | - ep_key_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
312 | | - ep_value_axis_names: AxisNames = (KV_BATCH_NO_EXP, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
313 | | - input_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, ATTN_EMBED), |
314 | | - out_axis_names: AxisNames = (BATCH, ATTN_LENGTH_NO_EXP, HEAD, D_KV), |
| 298 | + query_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 299 | + key_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 300 | + value_axis_names: AxisNames = (KV_BATCH, ATTN_LENGTH, KV_HEAD, KV_HEAD_DIM), |
| 301 | + input_axis_names: AxisNames = (BATCH, ATTN_LENGTH, ATTN_EMBED), |
| 302 | + out_axis_names: AxisNames = (BATCH, ATTN_LENGTH, HEAD, D_KV), |
315 | 303 | prefill_input_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, ATTN_EMBED), |
316 | 304 | decode_input_axis_names: AxisNames = (DECODE_BATCH, DECODE_LENGTH, ATTN_EMBED), |
317 | 305 | prefill_out_axis_names: AxisNames = (PREFILL_KV_BATCH, PREFILL_LENGTH, HEAD, D_KV), |
@@ -413,9 +401,6 @@ def __init__( |
413 | 401 | self.query_axis_names = query_axis_names |
414 | 402 | self.key_axis_names = key_axis_names |
415 | 403 | self.value_axis_names = value_axis_names |
416 | | - self.ep_query_axis_names = ep_query_axis_names |
417 | | - self.ep_key_axis_names = ep_key_axis_names |
418 | | - self.ep_value_axis_names = ep_value_axis_names |
419 | 404 | self.input_axis_names = input_axis_names |
420 | 405 | self.out_axis_names = out_axis_names |
421 | 406 | self.prefill_input_axis_names = prefill_input_axis_names |
@@ -1161,10 +1146,6 @@ def __call__( |
1161 | 1146 | query = self._maybe_shard_with_logical(query, (DECODE_BATCH, DECODE_LENGTH, HEAD, D_KV)) |
1162 | 1147 | key = self._maybe_shard_with_logical(key, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) |
1163 | 1148 | value = self._maybe_shard_with_logical(value, (DECODE_BATCH, DECODE_LENGTH, KV_HEAD, D_KV)) |
1164 | | - elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT: |
1165 | | - query = self._maybe_shard_with_logical(query, self.ep_query_axis_names) |
1166 | | - key = self._maybe_shard_with_logical(key, self.ep_key_axis_names) |
1167 | | - value = self._maybe_shard_with_logical(value, self.ep_value_axis_names) |
1168 | 1149 | else: |
1169 | 1150 | query = self._maybe_shard_with_logical(query, self.query_axis_names) |
1170 | 1151 | key = self._maybe_shard_with_logical(key, self.key_axis_names) |
|
0 commit comments